diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b66bba46..a905ad3e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: CI on: push: - branches: [ master ] + branches: [ master, v5-dev ] pull_request: branches: [ master ] @@ -14,21 +14,52 @@ jobs: strategy: matrix: - go-version: [1.16, 1.17] + go-version: [1.18] pg-version: [10, 11, 12, 13, 14, cockroachdb] include: - pg-version: 10 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 11 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 12 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 13 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 14 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: cockroachdb pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" + pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" steps: @@ -49,3 +80,9 @@ jobs: run: go test -race ./... env: PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }} + PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }} + PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} + PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }} + PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }} + PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }} + PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }} diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..e176228e --- /dev/null +++ b/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - 1.x + - tip + +matrix: + allow_failures: + - go: tip diff --git a/CHANGELOG.md b/CHANGELOG.md index e8f20129..32acfdda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,268 +1,177 @@ -# 4.17.2 (September 3, 2022) +# v5.0.0 -* Fix panic when logging batch error (Tom Möller) +## Merged Packages -# 4.17.1 (August 27, 2022) +`github.com/jackc/pgtype`, `github.com/jackc/pgconn`, and `github.com/jackc/pgproto3` are now included in the main +`github.com/jackc/pgx` repository. Previously there was confusion as to where issues should be reported, additional +release work due to releasing multiple packages, and less clear changelogs. -* Upgrade puddle to v1.3.0 - fixes context failing to cancel Acquire when acquire is creating resource which was introduced in v4.17.0 (James Hartig) -* Fix atomic alignment on 32-bit platforms +## pgconn -# 4.17.0 (August 6, 2022) +`CommandTag` is now an opaque type instead of directly exposing an underlying `[]byte`. -* Upgrade pgconn to v1.13.0 -* Upgrade pgproto3 to v2.3.1 -* Upgrade pgtype to v1.12.0 -* Allow background pool connections to continue even if cause is canceled (James Hartig) -* Add LoggerFunc (Gabor Szabad) -* pgxpool: health check should avoid going below minConns (James Hartig) -* Add pgxpool.Conn.Hijack() -* Logging improvements (Stepan Rabotkin) +The return value `ResultReader.Values()` is no longer safe to retain a reference to after a subsequent call to `NextRow()` or `Close()`. -# 4.16.1 (May 7, 2022) +`Trace()` method adds low level message tracing similar to the `PQtrace` function in `libpq`. -* Upgrade pgconn to v1.12.1 -* Fix explicitly prepared statements with describe statement cache mode +pgconn now uses non-blocking IO. This is a significant internal restructuring, but it should not cause any visible changes on its own. However, it is important in implementing other new features. -# 4.16.0 (April 21, 2022) +`CheckConn()` checks a connection's liveness by doing a non-blocking read. This can be used to detect database restarts or network interruptions without executing a query or a ping. -* Upgrade pgconn to v1.12.0 -* Upgrade pgproto3 to v2.3.0 -* Upgrade pgtype to v1.11.0 -* Fix: Do not panic when context cancelled while getting statement from cache. -* Fix: Less memory pinning from old Rows. -* Fix: Support '\r' line ending when sanitizing SQL comment. -* Add pluggable GSSAPI support (Oliver Tan) +pgconn now supports pipeline mode. -# 4.15.0 (February 7, 2022) +`*PgConn.ReceiveResults` removed. Use pipeline mode instead. -* Upgrade to pgconn v1.11.0 -* Upgrade to pgtype v1.10.0 -* Upgrade puddle to v1.2.1 -* Make BatchResults.Close safe to be called multiple times +`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error. -# 4.14.1 (November 28, 2021) +## pgxpool -* Upgrade pgtype to v1.9.1 (fixes unintentional change to timestamp binary decoding) -* Start pgxpool background health check after initial connections +`Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect. -# 4.14.0 (November 20, 2021) +## pgtype -* Upgrade pgconn to v1.10.1 -* Upgrade pgproto3 to v2.2.0 -* Upgrade pgtype to v1.9.0 -* Upgrade puddle to v1.2.0 -* Add QueryFunc to BatchResults -* Add context options to zerologadapter (Thomas Frössman) -* Add zerologadapter.NewContextLogger (urso) -* Eager initialize minpoolsize on connect (Daniel) -* Unpin memory used by large queries immediately after use +The `pgtype` package has been significantly changed. -# 4.13.0 (July 24, 2021) +### NULL Representation -* Trimmed pseudo-dependencies in Go modules from other packages tests -* Upgrade pgconn -- context cancellation no longer will return a net.Error -* Support time durations for simple protocol (Michael Darr) +Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a +`Valid` `bool` field to harmonize with how `database/sql` represents `NULL` and to make the zero value useable. -# 4.12.0 (July 10, 2021) +### Codec and Value Split -* ResetSession hook is called before a connection is reused from pool for another query (Dmytro Haranzha) -* stdlib: Add RandomizeHostOrderFunc (dkinder) -* stdlib: add OptionBeforeConnect (dkinder) -* stdlib: Do not reuse ConnConfig strings (Andrew Kimball) -* stdlib: implement Conn.ResetSession (Jonathan Amsterdam) -* Upgrade pgconn to v1.9.0 -* Upgrade pgtype to v1.8.0 +Previously, the type system combined decoding and encoding values with the value types. e.g. Type `Int8` both handled +encoding and decoding the PostgreSQL representation and acted as a value object. This caused some difficulties when +there was not an exact 1 to 1 relationship between the Go types and the PostgreSQL types For example, scanning a +PostgreSQL binary `numeric` into a Go `float64` was awkward (see https://github.com/jackc/pgtype/issues/147). This +concepts have been separated. A `Codec` only has responsibility for encoding and decoding values. Value types are +generally defined by implementing an interface that a particular `Codec` understands (e.g. `PointScanner` and +`PointValuer` for the PostgreSQL `point` type). -# 4.11.0 (March 25, 2021) +### Array Types -* Add BeforeConnect callback to pgxpool.Config (Robert Froehlich) -* Add Ping method to pgxpool.Conn (davidsbond) -* Added a kitlog level log adapter (Fabrice Aneche) -* Make ScanArgError public to allow identification of offending column (Pau Sanchez) -* Add *pgxpool.AcquireFunc -* Add BeginFunc and BeginTxFunc -* Add prefer_simple_protocol to connection string -* Add logging on CopyFrom (Patrick Hemmer) -* Add comment support when sanitizing SQL queries (Rusakow Andrew) -* Do not panic on double close of pgxpool.Pool (Matt Schultz) -* Avoid panic on SendBatch on closed Tx (Matt Schultz) -* Update pgconn to v1.8.1 -* Update pgtype to v1.7.0 +All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This also +means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional +arrays. -# 4.10.1 (December 19, 2020) +### Composite Types -* Fix panic on Query error with nil stmtcache. +Composite types must be registered before use. `CompositeFields` may still be used to construct and destruct composite +values, but any type may now implement `CompositeIndexGetter` and `CompositeIndexScanner` to be used as a composite. -# 4.10.0 (December 3, 2020) +### Range Types -* Add CopyFromSlice to simplify CopyFrom usage (Egon Elbre) -* Remove broken prepared statements from stmtcache (Ethan Pailes) -* stdlib: consider any Ping error as fatal -* Update puddle to v1.1.3 - this fixes an issue where concurrent Acquires can hang when a connection cannot be established -* Update pgtype to v1.6.2 +Range types are now handled with types `RangeCodec` and `Range[T]`. This allows additional user defined range types to +easily be handled. Multirange types are handled similarly with `MultirangeCodec` and `Multirange[T]`. -# 4.9.2 (November 3, 2020) +### pgxtype -The underlying library updates fix an issue where appending to a scanned slice could corrupt other data. +`LoadDataType` moved to `*Conn` as `LoadType`. -* Update pgconn to v1.7.2 -* Update pgproto3 to v2.0.6 +### Bytea -# 4.9.1 (October 31, 2020) +The `Bytea` and `GenericBinary` types have been replaced. Use the following instead: -* Update pgconn to v1.7.1 -* Update pgtype to v1.6.1 -* Fix SendBatch of all prepared statements with statement cache disabled +* `[]byte` - For normal usage directly use `[]byte`. +* `DriverBytes` - Uses driver memory only available until next database method call. Avoids a copy and an allocation. +* `PreallocBytes` - Uses preallocated byte slice to avoid an allocation. +* `UndecodedBytes` - Avoids any decoding. Allows working with raw bytes. -# 4.9.0 (September 26, 2020) +### Dropped lib/pq Support -* pgxpool now waits for connection cleanup to finish before making room in pool for another connection. This prevents temporarily exceeding max pool size. -* Fix when scanning a column to nil to skip it on the first row but scanning it to a real value on a subsequent row. -* Fix prefer simple protocol with prepared statements. (Jinzhu) -* Fix FieldDescriptions not being available on Rows before calling Next the first time. -* Various minor fixes in updated versions of pgconn, pgtype, and puddle. +`pgtype` previously supported and was tested against [lib/pq](https://github.com/lib/pq). While it will continue to work +in most cases this is no longer supported. -# 4.8.1 (July 29, 2020) +### database/sql Scan -* Update pgconn to v1.6.4 - * Fix deadlock on error after CommandComplete but before ReadyForQuery - * Fix panic on parsing DSN with trailing '=' +Previously, most `Scan` implementations would convert `[]byte` to `string` automatically to decode a text value. Now +only `string` is handled. This is to allow the possibility of future binary support in `database/sql` mode by +considering `[]byte` to be binary format and `string` text format. This change should have no effect for any use with +`pgx`. The previous behavior was only necessary for `lib/pq` compatibility. -# 4.8.0 (July 22, 2020) +Added `*Map.SQLScanner` to create a `sql.Scanner` for types such as `[]int32` and `Range[T]` that do not implement +`sql.Scanner` directly. -* All argument types supported by native pgx should now also work through database/sql -* Update pgconn to v1.6.3 -* Update pgtype to v1.4.2 +### Number Type Fields Include Bit size -# 4.7.2 (July 14, 2020) +`Int2`, `Int4`, `Int8`, `Float4`, `Float8`, and `Uint32` fields now include bit size. e.g. `Int` is renamed to `Int64`. +This matches the convention set by `database/sql`. In addition, for comparable types like `pgtype.Int8` and +`sql.NullInt64` the structures are identical. This means they can be directly converted one to another. -* Improve performance of Columns() (zikaeroh) -* Fix fatal Commit() failure not being considered fatal -* Update pgconn to v1.6.2 -* Update pgtype to v1.4.1 +### 3rd Party Type Integrations -# 4.7.1 (June 29, 2020) +* Extracted integrations with https://github.com/shopspring/decimal and https://github.com/gofrs/uuid to + https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims + the pgx dependency tree. -* Fix stdlib decoding error with certain order and combination of fields +### Other Changes -# 4.7.0 (June 27, 2020) +* `Bit` and `Varbit` are both replaced by the `Bits` type. +* `CID`, `OID`, `OIDValue`, and `XID` are replaced by the `Uint32` type. +* `Hstore` is now defined as `map[string]*string`. +* `JSON` and `JSONB` types removed. Use `[]byte` or `string` directly. +* `QChar` type removed. Use `rune` or `byte` directly. +* `Inet` and `Cidr` types removed. Use `netip.Addr` and `netip.Prefix` directly. These types are more memory efficient than the previous `net.IPNet`. +* `Macaddr` type removed. Use `net.HardwareAddr` directly. +* Renamed `pgtype.ConnInfo` to `pgtype.Map`. +* Renamed `pgtype.DataType` to `pgtype.Type`. +* Renamed `pgtype.None` to `pgtype.Finite`. +* `RegisterType` now accepts a `*Type` instead of `Type`. +* Assorted array helper methods and types made private. -* Update pgtype to v1.4.0 -* Update pgconn to v1.6.1 -* Update puddle to v1.1.1 -* Fix context propagation with Tx commit and Rollback (georgysavva) -* Add lazy connect option to pgxpool (georgysavva) -* Fix connection leak if pgxpool.BeginTx() fail (Jean-Baptiste Bronisz) -* Add native Go slice support for strings and numbers to simple protocol -* stdlib add default timeouts for Conn.Close() and Stmt.Close() (georgysavva) -* Assorted performance improvements especially with large result sets -* Fix close pool on not lazy connect failure (Yegor Myskin) -* Add Config copy (georgysavva) -* Support SendBatch with Simple Protocol (Jordan Lewis) -* Better error logging on rows close (Igor V. Kozinov) -* Expose stdlib.Conn.Conn() to enable database/sql.Conn.Raw() -* Improve unknown type support for database/sql -* Fix transaction commit failure closing connection +## stdlib -# 4.6.0 (March 30, 2020) +* Removed `AcquireConn` and `ReleaseConn` as that functionality has been built in since Go 1.13. -* stdlib: Bail early if preloading rows.Next() results in rows.Err() (Bas van Beek) -* Sanitize time to microsecond accuracy (Andrew Nicoll) -* Update pgtype to v1.3.0 -* Update pgconn to v1.5.0 - * Update golang.org/x/crypto for security fix - * Implement "verify-ca" SSL mode +## Reduced Memory Usage by Reusing Read Buffers -# 4.5.0 (March 7, 2020) +Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed +transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy. +However, this came at the cost of overall increased memory allocation size. But worse it was also possible to pin large +chunks of memory by retaining a reference to a small value that originally came directly from the read buffer. Now +ownership remains with the read buffer and anything needing to retain a value must make a copy. -* Update to pgconn v1.4.0 - * Fixes QueryRow with empty SQL - * Adds PostgreSQL service file support -* Add Len() to *pgx.Batch (WGH) -* Better logging for individual batch items (Ben Bader) +## Query Execution Modes -# 4.4.1 (February 14, 2020) +Control over automatic prepared statement caching and simple protocol use are now combined into query execution mode. +See documentation for `QueryExecMode`. -* Update pgconn to v1.3.2 - better default read buffer size -* Fix race in CopyFrom +## QueryRewriter Interface and NamedArgs -# 4.4.0 (February 5, 2020) +pgx now supports named arguments with the `NamedArgs` type. This is implemented via the new `QueryRewriter` interface which +allows arbitrary rewriting of query SQL and arguments. -* Update puddle to v1.1.0 - fixes possible deadlock when acquire is cancelled -* Update pgconn to v1.3.1 - fixes CopyFrom deadlock when multiple NoticeResponse received during copy -* Update pgtype to v1.2.0 -* Add MaxConnIdleTime to pgxpool (Patrick Ellul) -* Add MinConns to pgxpool (Patrick Ellul) -* Fix: stdlib.ReleaseConn closes connections left in invalid state +## RowScanner Interface -# 4.3.0 (January 23, 2020) +The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row. -* Fix Rows.Values panic when unable to decode -* Add Rows.Values support for unknown types -* Add DriverContext support for stdlib (Alex Gaynor) -* Update pgproto3 to v2.0.1 to never return an io.EOF as it would be misinterpreted by database/sql. Instead return io.UnexpectedEOF. +## Rows Result Helpers -# 4.2.1 (January 13, 2020) +* `CollectRows` and `RowTo*` functions simplify collecting results into a slice. +* `CollectOneRow` collects one row using `RowTo*` functions. +* `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`. -* Update pgconn to v1.2.1 (fixes context cancellation data race introduced in v1.2.0)) +## Tx Helpers -# 4.2.0 (January 11, 2020) +Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and +`BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`. -* Update pgconn to v1.2.0. -* Update pgtype to v1.1.0. -* Return error instead of panic when wrong number of arguments passed to Exec. (malstoun) -* Fix large objects functionality when PreferSimpleProtocol = true. -* Restore GetDefaultDriver which existed in v3. (Johan Brandhorst) -* Add RegisterConnConfig to stdlib which replaces the removed RegisterDriverConfig from v3. +## Improved Batch Query Ergonomics -# 4.1.2 (October 22, 2019) +Previously, the code for building a batch went in one place before the call to `SendBatch`, and the code for reading the +results went in one place after the call to `SendBatch`. This could make it difficult to match up the query and the code +to handle the results. Now `Queue` returns a `QueuedQuery` which has methods `Query`, `QueryRow`, and `Exec` which can +be used to register a callback function that will handle the result. Callback functions are called automatically when +`BatchResults.Close` is called. -* Fix dbSavepoint.Begin recursive self call -* Upgrade pgtype to v1.0.2 - fix scan pointer to pointer +## SendBatch Uses Pipeline Mode When Appropriate -# 4.1.1 (October 21, 2019) +Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 +for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements +in a single network round trip. So it would only take 2 round trips. -* Fix pgxpool Rows.CommandTag() infinite loop / typo +## Tracing and Logging -# 4.1.0 (October 12, 2019) +Internal logging support has been replaced with tracing hooks. This allows custom tracing integration with tools like OpenTelemetry. Package tracelog provides an adapter for pgx v4 loggers to act as a tracer. -## Potentially Breaking Changes - -Technically, two changes are breaking changes, but in practice these are extremely unlikely to break existing code. - -* Conn.Begin and Conn.BeginTx return a Tx interface instead of the internal dbTx struct. This is necessary for the Conn.Begin method to signature as other methods that begin a transaction. -* Add Conn() to Tx interface. This is necessary to allow code using a Tx to access the *Conn (and pgconn.PgConn) on which the Tx is executing. - -## Fixes - -* Releasing a busy connection closes the connection instead of returning an unusable connection to the pool -* Do not mutate config.Config.OnNotification in connect - -# 4.0.1 (September 19, 2019) - -* Fix statement cache cleanup. -* Corrected daterange OID. -* Fix Tx when committing or rolling back multiple times in certain cases. -* Improve documentation. - -# 4.0.0 (September 14, 2019) - -v4 is a major release with many significant changes some of which are breaking changes. The most significant are -included below. - -* Simplified establishing a connection with a connection string. -* All potentially blocking operations now require a context.Context. The non-context aware functions have been removed. -* OIDs are hard-coded for known types. This saves the query on connection. -* Context cancellations while network activity is in progress is now always fatal. Previously, it was sometimes recoverable. This led to increased complexity in pgx itself and in application code. -* Go modules are required. -* Errors are now implemented in the Go 1.13 style. -* `Rows` and `Tx` are now interfaces. -* The connection pool as been decoupled from pgx and is now a separate, included package (github.com/jackc/pgx/v4/pgxpool). -* pgtype has been spun off to a separate package (github.com/jackc/pgtype). -* pgproto3 has been spun off to a separate package (github.com/jackc/pgproto3/v2). -* Logical replication support has been spun off to a separate package (github.com/jackc/pglogrepl). -* Lower level PostgreSQL functionality is now implemented in a separate package (github.com/jackc/pgconn). -* Tests are now configured with environment variables. -* Conn has an automatic statement cache by default. -* Batch interface has been simplified. -* QueryArgs has been removed. +All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency +tree. diff --git a/README.md b/README.md index 16d8f46f..fbffcf7e 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,17 @@ -[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://pkg.go.dev/github.com/jackc/pgx/v4) -[![Build Status](https://travis-ci.org/jackc/pgx.svg)](https://travis-ci.org/jackc/pgx) +[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5) +![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg) ---- - -This is the stable `v4` release. `v5` is now in beta testing with final release expected in September. See https://github.com/jackc/pgx/issues/1273 for more information. Please consider testing `v5`. - ---- # pgx - PostgreSQL Driver and Toolkit pgx is a pure Go driver and toolkit for PostgreSQL. -pgx aims to be low-level, fast, and performant, while also enabling PostgreSQL-specific features that the standard `database/sql` package does not allow for. - -The driver component of pgx can be used alongside the standard `database/sql` package. +The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` / +`NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface. The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers, proxies, load balancers, logical replication clients, etc. -The current release of `pgx v4` requires Go modules. To use the previous version, checkout and vendor the `v3` branch. - ## Example Usage ```go @@ -30,7 +22,7 @@ import ( "fmt" "os" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" ) func main() { @@ -56,52 +48,39 @@ func main() { See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information. -## Choosing Between the pgx and database/sql Interfaces - -It is recommended to use the pgx interface if: -1. The application only targets PostgreSQL. -2. No other libraries that require `database/sql` are in use. - -The pgx interface is faster and exposes more features. - -The `database/sql` interface only allows the underlying driver to return or receive the following types: `int64`, -`float64`, `bool`, `[]byte`, `string`, `time.Time`, or `nil`. Handling other types requires implementing the -`database/sql.Scanner` and the `database/sql/driver/driver.Valuer` interfaces which require transmission of values in text format. The binary format can be substantially faster, which is what the pgx interface uses. - ## Features -pgx supports many features beyond what is available through `database/sql`: - * Support for approximately 70 different PostgreSQL types * Automatic statement preparation and caching * Batch queries * Single-round trip query mode * Full TLS connection control * Binary format support for custom types (allows for much quicker encoding/decoding) -* COPY protocol support for faster bulk data loads -* Extendable logging support including built-in support for `log15adapter`, [`logrus`](https://github.com/sirupsen/logrus), [`zap`](https://github.com/uber-go/zap), and [`zerolog`](https://github.com/rs/zerolog) +* `COPY` protocol support for faster bulk data loads +* Tracing and logging support * Connection pool with after-connect hook for arbitrary connection setup -* Listen / notify +* `LISTEN` / `NOTIFY` * Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings -* Hstore support -* JSON and JSONB support -* Maps `inet` and `cidr` PostgreSQL types to `net.IPNet` and `net.IP` +* `hstore` support +* `json` and `jsonb` support +* Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix` * Large object support -* NULL mapping to Null* struct or pointer to pointer +* NULL mapping to pointer to pointer * Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types * Notice response handling * Simulated nested transactions with savepoints -## Performance +## Choosing Between the pgx and database/sql Interfaces -There are three areas in particular where pgx can provide a significant performance advantage over the standard -`database/sql` interface and other drivers: +The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available +through the `database/sql` interface. -1. PostgreSQL specific types - Types such as arrays can be parsed much quicker because pgx uses the binary format. -2. Automatic statement preparation and caching - pgx will prepare and cache statements by default. This can provide an - significant free improvement to code that does not explicitly use prepared statements. Under certain workloads, it can - perform nearly 3x the number of queries per second. -3. Batched queries - Multiple queries can be batched together to minimize network round trips. +The pgx interface is recommended when: + +1. The application only targets PostgreSQL. +2. No other libraries that require `database/sql` are in use. + +It is also possible to use the `database/sql` interface and convert a connection to the lower-level pgx interface as needed. ## Testing @@ -134,37 +113,14 @@ In addition, there are tests specific for PgBouncer that will be executed if `PG ## Supported Go and PostgreSQL Versions -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.16 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.18 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). ## Version Policy -pgx follows semantic versioning for the documented public API on stable releases. `v4` is the latest stable major version. +pgx follows semantic versioning for the documented public API on stable releases. `v5` is the latest stable major version. ## PGX Family Libraries -pgx is the head of a family of PostgreSQL libraries. Many of these can be used independently. Many can also be accessed -from pgx for lower-level control. - -### [github.com/jackc/pgconn](https://github.com/jackc/pgconn) - -`pgconn` is a lower-level PostgreSQL database driver that operates at nearly the same level as the C library `libpq`. - -### [github.com/jackc/pgx/v4/pgxpool](https://github.com/jackc/pgx/tree/master/pgxpool) - -`pgxpool` is a connection pool for pgx. pgx is entirely decoupled from its default pool implementation. This means that pgx can be used with a different pool or without any pool at all. - -### [github.com/jackc/pgx/v4/stdlib](https://github.com/jackc/pgx/tree/master/stdlib) - -This is a `database/sql` compatibility layer for pgx. pgx can be used as a normal `database/sql` driver, but at any time, the native interface can be acquired for more performance or PostgreSQL specific functionality. - -### [github.com/jackc/pgtype](https://github.com/jackc/pgtype) - -Over 70 PostgreSQL types are supported including `uuid`, `hstore`, `json`, `bytea`, `numeric`, `interval`, `inet`, and arrays. These types support `database/sql` interfaces and are usable outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver. - -### [github.com/jackc/pgproto3](https://github.com/jackc/pgproto3) - -pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling. - ### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl) pglogrepl provides functionality to act as a client for PostgreSQL logical replication. @@ -181,6 +137,22 @@ tern is a stand-alone SQL migration system. pgerrcode contains constants for the PostgreSQL error codes. +## Adapters for 3rd Party Types + +* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid) +* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal) +* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) + +## Adapters for 3rd Party Loggers + +These adapters can be used with the tracelog package. + +* [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log) +* [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15) +* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus) +* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap) +* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog) + ## 3rd Party Libraries with PGX Support ### [github.com/georgysavva/scany](https://github.com/georgysavva/scany) @@ -190,7 +162,3 @@ Library for scanning data from a database into Go structs and more. ### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) Adds GSSAPI / Kerberos authentication support. - -### [https://github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) - -Adds support for [`github.com/google/uuid`](https://github.com/google/uuid). diff --git a/Rakefile b/Rakefile new file mode 100644 index 00000000..d957573e --- /dev/null +++ b/Rakefile @@ -0,0 +1,18 @@ +require "erb" + +rule '.go' => '.go.erb' do |task| + erb = ERB.new(File.read(task.source)) + File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding)) + sh "goimports", "-w", task.name +end + +generated_code_files = [ + "pgtype/int.go", + "pgtype/int_test.go", + "pgtype/integration_benchmark_test.go", + "pgtype/zeronull/int.go", + "pgtype/zeronull/int_test.go" +] + +desc "Generate code" +task generate: generated_code_files diff --git a/batch.go b/batch.go index 7f86ad5c..af62039f 100644 --- a/batch.go +++ b/batch.go @@ -5,69 +5,123 @@ import ( "errors" "fmt" - "github.com/jackc/pgconn" + "github.com/jackc/pgx/v5/pgconn" ) -type batchItem struct { +// QueuedQuery is a query that has been queued for execution via a Batch. +type QueuedQuery struct { query string - arguments []interface{} + arguments []any + fn batchItemFunc + sd *pgconn.StatementDescription +} + +type batchItemFunc func(br BatchResults) error + +// Query sets fn to be called when the response to qq is received. +func (qq *QueuedQuery) Query(fn func(rows Rows) error) { + qq.fn = func(br BatchResults) error { + rows, err := br.Query() + if err != nil { + return err + } + defer rows.Close() + + err = fn(rows) + if err != nil { + return err + } + rows.Close() + + return rows.Err() + } +} + +// Query sets fn to be called when the response to qq is received. +func (qq *QueuedQuery) QueryRow(fn func(row Row) error) { + qq.fn = func(br BatchResults) error { + row := br.QueryRow() + return fn(row) + } +} + +// Exec sets fn to be called when the response to qq is received. +func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) { + qq.fn = func(br BatchResults) error { + ct, err := br.Exec() + if err != nil { + return err + } + + return fn(ct) + } } // Batch queries are a way of bundling multiple queries together to avoid -// unnecessary network round trips. +// unnecessary network round trips. A Batch must only be sent once. type Batch struct { - items []*batchItem + queuedQueries []*QueuedQuery } // Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. -func (b *Batch) Queue(query string, arguments ...interface{}) { - b.items = append(b.items, &batchItem{ +func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery { + qq := &QueuedQuery{ query: query, arguments: arguments, - }) + } + b.queuedQueries = append(b.queuedQueries, qq) + return qq } // Len returns number of queries that have been queued so far. func (b *Batch) Len() int { - return len(b.items) + return len(b.queuedQueries) } type BatchResults interface { - // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. + // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer + // calling Exec on the QueuedQuery. Exec() (pgconn.CommandTag, error) - // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. + // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer + // calling Query on the QueuedQuery. Query() (Rows, error) // QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow. + // Prefer calling QueryRow on the QueuedQuery. QueryRow() Row - // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. - QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) - - // Close closes the batch operation. This must be called before the underlying connection can be used again. Any error - // that occurred during a batch operation may have made it impossible to resyncronize the connection with the server. - // In this case the underlying connection will have been closed. Close is safe to call multiple times. + // Close closes the batch operation. All unread results are read and any callback functions registered with + // QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an + // error or the batch encounters an error subsequent callback functions will not be called. + // + // Close must be called before the underlying connection can be used again. Any error that occurred during a batch + // operation may have made it impossible to resyncronize the connection with the server. In this case the underlying + // connection will have been closed. + // + // Close is safe to call multiple times. If it returns an error subsequent calls will return the same error. Callback + // functions will not be rerun. Close() error } type batchResults struct { - ctx context.Context - conn *Conn - mrr *pgconn.MultiResultReader - err error - b *Batch - ix int - closed bool + ctx context.Context + conn *Conn + mrr *pgconn.MultiResultReader + err error + b *Batch + qqIdx int + closed bool + endTraced bool } // Exec reads the results from the next query in the batch as if the query has been sent with Exec. func (br *batchResults) Exec() (pgconn.CommandTag, error) { if br.err != nil { - return nil, br.err + return pgconn.CommandTag{}, br.err } if br.closed { - return nil, fmt.Errorf("batch already closed") + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") } query, arguments, _ := br.nextQueryAndArgs() @@ -77,35 +131,29 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { if err == nil { err = errors.New("no result") } - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ - "sql": query, - "args": logQueryArgs(arguments), - "err": err, + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: err, }) } - return nil, err + return pgconn.CommandTag{}, err } commandTag, err := br.mrr.ResultReader().Close() + br.err = err - if err != nil { - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ - "sql": query, - "args": logQueryArgs(arguments), - "err": err, - }) - } - } else if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{}{ - "sql": query, - "args": logQueryArgs(arguments), - "commandTag": commandTag, + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + CommandTag: commandTag, + Err: br.err, }) } - return commandTag, err + return commandTag, br.err } // Query reads the results from the next query in the batch as if the query has been sent with Query. @@ -116,15 +164,16 @@ func (br *batchResults) Query() (Rows, error) { } if br.err != nil { - return &connRows{err: br.err, closed: true}, br.err + return &baseRows{err: br.err, closed: true}, br.err } if br.closed { alreadyClosedErr := fmt.Errorf("batch already closed") - return &connRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr } rows := br.conn.getRows(br.ctx, query, arguments) + rows.batchTracer = br.conn.batchTracer if !br.mrr.NextResult() { rows.err = br.mrr.Close() @@ -133,11 +182,11 @@ func (br *batchResults) Query() (Rows, error) { } rows.closed = true - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{}{ - "sql": query, - "args": logQueryArgs(arguments), - "err": rows.err, + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: rows.err, }) } @@ -148,47 +197,25 @@ func (br *batchResults) Query() (Rows, error) { return rows, nil } -// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. -func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { - if br.closed { - return nil, fmt.Errorf("batch already closed") - } - - rows, err := br.Query() - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - err = rows.Scan(scans...) - if err != nil { - return nil, err - } - - err = f(rows) - if err != nil { - return nil, err - } - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return rows.CommandTag(), nil -} - // QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. func (br *batchResults) QueryRow() Row { rows, _ := br.Query() - return (*connRow)(rows.(*connRows)) + return (*connRow)(rows.(*baseRows)) } // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to // resyncronize the connection with the server. In this case the underlying connection will have been closed. func (br *batchResults) Close() error { + defer func() { + if !br.endTraced { + if br.conn != nil && br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err}) + } + br.endTraced = true + } + }() + if br.err != nil { return br.err } @@ -196,33 +223,213 @@ func (br *batchResults) Close() error { if br.closed { return nil } - br.closed = true - // log any queries that haven't yet been logged by Exec or Query - for { - query, args, ok := br.nextQueryAndArgs() - if !ok { - break - } - - if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{}{ - "sql": query, - "args": logQueryArgs(args), - }) + // Read and run fn for all remaining items + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + if br.b.queuedQueries[br.qqIdx].fn != nil { + err := br.b.queuedQueries[br.qqIdx].fn(br) + if err != nil && br.err == nil { + br.err = err + } + } else { + br.Exec() } } - return br.mrr.Close() + br.closed = true + + err := br.mrr.Close() + if br.err == nil { + br.err = err + } + + return br.err } -func (br *batchResults) nextQueryAndArgs() (query string, args []interface{}, ok bool) { - if br.b != nil && br.ix < len(br.b.items) { - bi := br.b.items[br.ix] +func (br *batchResults) earlyError() error { + return br.err +} + +func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { + if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + bi := br.b.queuedQueries[br.qqIdx] query = bi.query args = bi.arguments ok = true - br.ix++ + br.qqIdx++ + } + return +} + +type pipelineBatchResults struct { + ctx context.Context + conn *Conn + pipeline *pgconn.Pipeline + lastRows *baseRows + err error + b *Batch + qqIdx int + closed bool + endTraced bool +} + +// Exec reads the results from the next query in the batch as if the query has been sent with Exec. +func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { + if br.err != nil { + return pgconn.CommandTag{}, br.err + } + if br.closed { + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") + } + if br.lastRows != nil && br.lastRows.err != nil { + return pgconn.CommandTag{}, br.err + } + + query, arguments, _ := br.nextQueryAndArgs() + + results, err := br.pipeline.GetResults() + if err != nil { + br.err = err + return pgconn.CommandTag{}, err + } + var commandTag pgconn.CommandTag + switch results := results.(type) { + case *pgconn.ResultReader: + commandTag, br.err = results.Close() + default: + return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results) + } + + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + CommandTag: commandTag, + Err: br.err, + }) + } + + return commandTag, err +} + +// Query reads the results from the next query in the batch as if the query has been sent with Query. +func (br *pipelineBatchResults) Query() (Rows, error) { + if br.err != nil { + return &baseRows{err: br.err, closed: true}, br.err + } + + if br.closed { + alreadyClosedErr := fmt.Errorf("batch already closed") + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr + } + + if br.lastRows != nil && br.lastRows.err != nil { + br.err = br.lastRows.err + return &baseRows{err: br.err, closed: true}, br.err + } + + query, arguments, ok := br.nextQueryAndArgs() + if !ok { + query = "batch query" + } + + rows := br.conn.getRows(br.ctx, query, arguments) + rows.batchTracer = br.conn.batchTracer + br.lastRows = rows + + results, err := br.pipeline.GetResults() + if err != nil { + br.err = err + rows.err = err + rows.closed = true + + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: err, + }) + } + } else { + switch results := results.(type) { + case *pgconn.ResultReader: + rows.resultReader = results + default: + err = fmt.Errorf("unexpected pipeline result: %T", results) + br.err = err + rows.err = err + rows.closed = true + } + } + + return rows, rows.err +} + +// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. +func (br *pipelineBatchResults) QueryRow() Row { + rows, _ := br.Query() + return (*connRow)(rows.(*baseRows)) + +} + +// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to +// resyncronize the connection with the server. In this case the underlying connection will have been closed. +func (br *pipelineBatchResults) Close() error { + defer func() { + if !br.endTraced { + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err}) + } + br.endTraced = true + } + }() + + if br.err != nil { + return br.err + } + + if br.lastRows != nil && br.lastRows.err != nil { + br.err = br.lastRows.err + return br.err + } + + if br.closed { + return nil + } + + // Read and run fn for all remaining items + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + if br.b.queuedQueries[br.qqIdx].fn != nil { + err := br.b.queuedQueries[br.qqIdx].fn(br) + if err != nil && br.err == nil { + br.err = err + } + } else { + br.Exec() + } + } + + br.closed = true + + err := br.pipeline.Close() + if br.err == nil { + br.err = err + } + + return br.err +} + +func (br *pipelineBatchResults) earlyError() error { + return br.err +} + +func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { + if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + bi := br.b.queuedQueries[br.qqIdx] + query = bi.query + args = bi.arguments + ok = true + br.qqIdx++ } return } diff --git a/batch_test.go b/batch_test.go index f95e335e..2ade0d4a 100644 --- a/batch_test.go +++ b/batch_test.go @@ -3,12 +3,13 @@ package pgx_test import ( "context" "errors" + "fmt" "os" "testing" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,230 +17,351 @@ import ( func TestConnSendBatch(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test") - skipCockroachDB(t, conn, "Server serial type is incompatible with test") - - sql := `create temporary table ledger( + sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - batch := &pgx.Batch{} - batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) - batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) - batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) - batch.Queue("select id, description, amount from ledger order by id") - batch.Queue("select id, description, amount from ledger order by id") - batch.Queue("select * from ledger where false") - batch.Queue("select sum(amount) from ledger") - - br := conn.SendBatch(context.Background(), batch) - - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } - - ct, err = br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } - - ct, err = br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } - - selectFromLedgerExpectedRows := []struct { - id int32 - description string - amount int32 - }{ - {1, "q1", 1}, - {2, "q2", 2}, - {3, "q3", 3}, - } - - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - var id int32 - var description string - var amount int32 - rowCount := 0 - - for rows.Next() { - if rowCount >= len(selectFromLedgerExpectedRows) { - t.Fatalf("got too many rows: %d", rowCount) - } - - if err := rows.Scan(&id, &description, &amount); err != nil { - t.Fatalf("row %d: %v", rowCount, err) - } - - if id != selectFromLedgerExpectedRows[rowCount].id { - t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) - } - if description != selectFromLedgerExpectedRows[rowCount].description { - t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) - } - if amount != selectFromLedgerExpectedRows[rowCount].amount { - t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) - } - - rowCount++ - } - - if rows.Err() != nil { - t.Fatal(rows.Err()) - } - - rowCount = 0 - _, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error { - if id != selectFromLedgerExpectedRows[rowCount].id { - t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) - } - if description != selectFromLedgerExpectedRows[rowCount].description { - t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) - } - if amount != selectFromLedgerExpectedRows[rowCount].amount { - t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) - } - - rowCount++ - - return nil - }) - if err != nil { - t.Error(err) - } - - err = br.QueryRow().Scan(&id, &description, &amount) - if !errors.Is(err, pgx.ErrNoRows) { - t.Errorf("expected pgx.ErrNoRows but got: %v", err) - } - - err = br.QueryRow().Scan(&amount) - if err != nil { - t.Error(err) - } - if amount != 6 { - t.Errorf("amount => %v, want %v", amount, 6) - } - - err = br.Close() - if err != nil { - t.Fatal(err) - } - - ensureConnValid(t, conn) -} - -func TestConnSendBatchMany(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - sql := `create temporary table ledger( - id serial primary key, - description varchar not null, - amount int not null - );` - mustExec(t, conn, sql) - - batch := &pgx.Batch{} - - numInserts := 1000 - - for i := 0; i < numInserts; i++ { + batch := &pgx.Batch{} batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) - } - batch.Queue("select count(*) from ledger") + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) + batch.Queue("select id, description, amount from ledger order by id") + batch.Queue("select id, description, amount from ledger order by id") + batch.Queue("select * from ledger where false") + batch.Queue("select sum(amount) from ledger") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - for i := 0; i < numInserts; i++ { ct, err := br.Exec() - assert.NoError(t, err) - assert.EqualValues(t, 1, ct.RowsAffected()) - } + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - var actualInserts int - err := br.QueryRow().Scan(&actualInserts) - assert.NoError(t, err) - assert.EqualValues(t, numInserts, actualInserts) + ct, err = br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - err = br.Close() - require.NoError(t, err) + ct, err = br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - ensureConnValid(t, conn) -} + selectFromLedgerExpectedRows := []struct { + id int32 + description string + amount int32 + }{ + {1, "q1", 1}, + {2, "q2", 2}, + {3, "q3", 3}, + } -func TestConnSendBatchWithPreparedStatement(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") - - _, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") - if err != nil { - t.Fatal(err) - } - - batch := &pgx.Batch{} - - queryCount := 3 - for i := 0; i < queryCount; i++ { - batch.Queue("ps1", 5) - } - - br := conn.SendBatch(context.Background(), batch) - - for i := 0; i < queryCount; i++ { rows, err := br.Query() if err != nil { - t.Fatal(err) + t.Error(err) } - for k := 0; rows.Next(); k++ { - var n int - if err := rows.Scan(&n); err != nil { - t.Fatal(err) + var id int32 + var description string + var amount int32 + rowCount := 0 + + for rows.Next() { + if rowCount >= len(selectFromLedgerExpectedRows) { + t.Fatalf("got too many rows: %d", rowCount) } - if n != k { - t.Fatalf("n => %v, want %v", n, k) + + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatalf("row %d: %v", rowCount, err) } + + if id != selectFromLedgerExpectedRows[rowCount].id { + t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) + } + if description != selectFromLedgerExpectedRows[rowCount].description { + t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) + } + if amount != selectFromLedgerExpectedRows[rowCount].amount { + t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) + } + + rowCount++ } if rows.Err() != nil { t.Fatal(rows.Err()) } - } - err = br.Close() - if err != nil { - t.Fatal(err) - } + rowCount = 0 + rows, _ = br.Query() + _, err = pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error { + if id != selectFromLedgerExpectedRows[rowCount].id { + t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) + } + if description != selectFromLedgerExpectedRows[rowCount].description { + t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) + } + if amount != selectFromLedgerExpectedRows[rowCount].amount { + t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) + } - ensureConnValid(t, conn) + rowCount++ + + return nil + }) + if err != nil { + t.Error(err) + } + + err = br.QueryRow().Scan(&id, &description, &amount) + if !errors.Is(err, pgx.ErrNoRows) { + t.Errorf("expected pgx.ErrNoRows but got: %v", err) + } + + err = br.QueryRow().Scan(&amount) + if err != nil { + t.Error(err) + } + if amount != 6 { + t.Errorf("amount => %v, want %v", amount, 6) + } + + err = br.Close() + if err != nil { + t.Fatal(err) + } + }) +} + +func TestConnSendBatchQueuedQuery(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test") + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null + );` + mustExec(t, conn, sql) + + batch := &pgx.Batch{} + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + selectFromLedgerExpectedRows := []struct { + id int32 + description string + amount int32 + }{ + {1, "q1", 1}, + {2, "q2", 2}, + {3, "q3", 3}, + } + + batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error { + rowCount := 0 + var id int32 + var description string + var amount int32 + _, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error { + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount) + rowCount++ + + return nil + }) + assert.NoError(t, err) + return nil + }) + + batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error { + rowCount := 0 + var id int32 + var description string + var amount int32 + _, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error { + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount) + rowCount++ + + return nil + }) + assert.NoError(t, err) + return nil + }) + + batch.Queue("select * from ledger where false").QueryRow(func(row pgx.Row) error { + err := row.Scan(nil, nil, nil) + assert.ErrorIs(t, err, pgx.ErrNoRows) + return nil + }) + + batch.Queue("select sum(amount) from ledger").QueryRow(func(row pgx.Row) error { + var sumAmount int32 + err := row.Scan(&sumAmount) + assert.NoError(t, err) + assert.EqualValues(t, 6, sumAmount) + return nil + }) + + err := conn.SendBatch(context.Background(), batch).Close() + assert.NoError(t, err) + }) +} + +func TestConnSendBatchMany(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null + );` + mustExec(t, conn, sql) + + batch := &pgx.Batch{} + + numInserts := 1000 + + for i := 0; i < numInserts; i++ { + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) + } + batch.Queue("select count(*) from ledger") + + br := conn.SendBatch(context.Background(), batch) + + for i := 0; i < numInserts; i++ { + ct, err := br.Exec() + assert.NoError(t, err) + assert.EqualValues(t, 1, ct.RowsAffected()) + } + + var actualInserts int + err := br.QueryRow().Scan(&actualInserts) + assert.NoError(t, err) + assert.EqualValues(t, numInserts, actualInserts) + + err = br.Close() + require.NoError(t, err) + }) +} + +func TestConnSendBatchWithPreparedStatement(t *testing.T) { + t.Parallel() + + modes := []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + // Don't test simple mode with prepared statements. + } + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") + _, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") + if err != nil { + t.Fatal(err) + } + + batch := &pgx.Batch{} + + queryCount := 3 + for i := 0; i < queryCount; i++ { + batch.Queue("ps1", 5) + } + + br := conn.SendBatch(context.Background(), batch) + + for i := 0; i < queryCount; i++ { + rows, err := br.Query() + if err != nil { + t.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Fatal(err) + } + if n != k { + t.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + } + + err = br.Close() + if err != nil { + t.Fatal(err) + } + }) +} + +func TestConnSendBatchWithQueryRewriter(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}}) + batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}}) + batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}}) + + br := conn.SendBatch(context.Background(), batch) + + var n int32 + err := br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + var s string + err = br.QueryRow().Scan(&s) + require.NoError(t, err) + require.Equal(t, "hello", s) + + err = br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 3, n) + + err = br.Close() + require.NoError(t, err) + }) } // https://github.com/jackc/pgx/issues/856 @@ -249,12 +371,14 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing. config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(t, config) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") + pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") _, err = conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") if err != nil { @@ -302,316 +426,308 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing. func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - batch := &pgx.Batch{} - batch.Queue("select n from generate_series(0,5) n") - batch.Queue("select n from generate_series(0,5) n") + batch := &pgx.Batch{} + batch.Queue("select n from generate_series(0,5) n") + batch.Queue("select n from generate_series(0,5) n") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - for i := 0; i < 3; i++ { - if !rows.Next() { - t.Error("expected a row to be available") - } - - var n int - if err := rows.Scan(&n); err != nil { + rows, err := br.Query() + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + + for i := 0; i < 3; i++ { + if !rows.Next() { + t.Error("expected a row to be available") + } + + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } } - } - rows.Close() + rows.Close() - rows, err = br.Query() - if err != nil { - t.Error(err) - } - - for i := 0; rows.Next(); i++ { - var n int - if err := rows.Scan(&n); err != nil { + rows, err = br.Query() + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } } - } - if rows.Err() != nil { - t.Error(rows.Err()) - } + if rows.Err() != nil { + t.Error(rows.Err()) + } - err = br.Close() - if err != nil { - t.Fatal(err) - } + err = br.Close() + if err != nil { + t.Fatal(err) + } - ensureConnValid(t, conn) + }) } func TestConnSendBatchQueryError(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - batch := &pgx.Batch{} - batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") - batch.Queue("select n from generate_series(0,5) n") + batch := &pgx.Batch{} + batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") + batch.Queue("select n from generate_series(0,5) n") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - for i := 0; rows.Next(); i++ { - var n int - if err := rows.Scan(&n); err != nil { + rows, err := br.Query() + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } } - } - if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") { - t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) - } + if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) + } - err = br.Close() - if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { - t.Errorf("rows.Err() => %v, want error code %v", err, 22012) - } + err = br.Close() + if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("br.Close() => %v, want error code %v", err, 22012) + } - ensureConnValid(t, conn) + }) } func TestConnSendBatchQuerySyntaxError(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - batch := &pgx.Batch{} - batch.Queue("select 1 1") + batch := &pgx.Batch{} + batch.Queue("select 1 1") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - var n int32 - err := br.QueryRow().Scan(&n) - if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") { - t.Errorf("rows.Err() => %v, want error code %v", err, 42601) - } + var n int32 + err := br.QueryRow().Scan(&n) + if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") { + t.Errorf("rows.Err() => %v, want error code %v", err, 42601) + } - err = br.Close() - if err == nil { - t.Error("Expected error") - } + err = br.Close() + if err == nil { + t.Error("Expected error") + } - ensureConnValid(t, conn) + }) } func TestConnSendBatchQueryRowInsert(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - sql := `create temporary table ledger( + sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - batch := &pgx.Batch{} - batch.Queue("select 1") - batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) + batch := &pgx.Batch{} + batch.Queue("select 1") + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - var value int - err := br.QueryRow().Scan(&value) - if err != nil { - t.Error(err) - } + var value int + err := br.QueryRow().Scan(&value) + if err != nil { + t.Error(err) + } - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 2 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) - } + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) + } - br.Close() + br.Close() - ensureConnValid(t, conn) + }) } func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - sql := `create temporary table ledger( + sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - batch := &pgx.Batch{} - batch.Queue("select 1 union all select 2 union all select 3") - batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) + batch := &pgx.Batch{} + batch.Queue("select 1 union all select 2 union all select 3") + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - rows.Close() + rows, err := br.Query() + if err != nil { + t.Error(err) + } + rows.Close() - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 2 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) - } + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) + } - br.Close() + br.Close() - ensureConnValid(t, conn) + }) } func TestTxSendBatch(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - sql := `create temporary table ledger1( + sql := `create temporary table ledger1( id serial primary key, description varchar not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - sql = `create temporary table ledger2( + sql = `create temporary table ledger2( id int primary key, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - tx, _ := conn.Begin(context.Background()) - batch := &pgx.Batch{} - batch.Queue("insert into ledger1(description) values($1) returning id", "q1") + tx, _ := conn.Begin(context.Background()) + batch := &pgx.Batch{} + batch.Queue("insert into ledger1(description) values($1) returning id", "q1") - br := tx.SendBatch(context.Background(), batch) + br := tx.SendBatch(context.Background(), batch) - var id int - err := br.QueryRow().Scan(&id) - if err != nil { - t.Error(err) - } - br.Close() + var id int + err := br.QueryRow().Scan(&id) + if err != nil { + t.Error(err) + } + br.Close() - batch = &pgx.Batch{} - batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2) - batch.Queue("select amount from ledger2 where id = $1", id) + batch = &pgx.Batch{} + batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2) + batch.Queue("select amount from ledger2 where id = $1", id) - br = tx.SendBatch(context.Background(), batch) + br = tx.SendBatch(context.Background(), batch) - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - var amount int - err = br.QueryRow().Scan(&amount) - if err != nil { - t.Error(err) - } + var amount int + err = br.QueryRow().Scan(&amount) + if err != nil { + t.Error(err) + } - br.Close() - tx.Commit(context.Background()) + br.Close() + tx.Commit(context.Background()) - var count int - conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count) - if count != 1 { - t.Errorf("count => %v, want %v", count, 1) - } + var count int + conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count) + if count != 1 { + t.Errorf("count => %v, want %v", count, 1) + } - err = br.Close() - if err != nil { - t.Fatal(err) - } + err = br.Close() + if err != nil { + t.Fatal(err) + } - ensureConnValid(t, conn) + }) } func TestTxSendBatchRollback(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - sql := `create temporary table ledger1( + sql := `create temporary table ledger1( id serial primary key, description varchar not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - tx, _ := conn.Begin(context.Background()) - batch := &pgx.Batch{} - batch.Queue("insert into ledger1(description) values($1) returning id", "q1") + tx, _ := conn.Begin(context.Background()) + batch := &pgx.Batch{} + batch.Queue("insert into ledger1(description) values($1) returning id", "q1") - br := tx.SendBatch(context.Background(), batch) + br := tx.SendBatch(context.Background(), batch) - var id int - err := br.QueryRow().Scan(&id) - if err != nil { - t.Error(err) - } - br.Close() - tx.Rollback(context.Background()) + var id int + err := br.QueryRow().Scan(&id) + if err != nil { + t.Error(err) + } + br.Close() + tx.Rollback(context.Background()) - row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id) - var count int - row.Scan(&count) - if count != 0 { - t.Errorf("count => %v, want %v", count, 0) - } + row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id) + var count int + row.Scan(&count) + if count != 0 { + t.Errorf("count => %v, want %v", count, 0) + } - ensureConnValid(t, conn) + }) } func TestConnBeginBatchDeferredError(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") - mustExec(t, conn, `create temporary table t ( + mustExec(t, conn, `create temporary table t ( id text primary key, n int not null, unique (n) deferrable initially deferred @@ -619,41 +735,43 @@ func TestConnBeginBatchDeferredError(t *testing.T) { insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) - batch := &pgx.Batch{} + batch := &pgx.Batch{} - batch.Queue(`update t set n=n+1 where id='b' returning *`) + batch.Queue(`update t set n=n+1 where id='b' returning *`) - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - for rows.Next() { - var id string - var n int32 - err = rows.Scan(&id, &n) + rows, err := br.Query() if err != nil { - t.Fatal(err) + t.Error(err) } - } - err = br.Close() - if err == nil { - t.Fatal("expected error 23505 but got none") - } + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } - if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" { - t.Fatalf("expected error 23505, got %v", err) - } + err = br.Close() + if err == nil { + t.Fatal("expected error 23505 but got none") + } - ensureConnValid(t, conn) + if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + }) } func TestConnSendBatchNoStatementCache(t *testing.T) { config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(t, config) defer closeConn(t, conn) @@ -663,9 +781,8 @@ func TestConnSendBatchNoStatementCache(t *testing.T) { func TestConnSendBatchPrepareStatementCache(t *testing.T) { config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 conn := mustConnect(t, config) defer closeConn(t, conn) @@ -675,9 +792,8 @@ func TestConnSendBatchPrepareStatementCache(t *testing.T) { func TestConnSendBatchDescribeStatementCache(t *testing.T) { config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.DescriptionCacheCapacity = 32 conn := mustConnect(t, config) defer closeConn(t, conn) @@ -711,107 +827,11 @@ func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) { require.NoError(t, err) } -func TestLogBatchStatementsOnExec(t *testing.T) { - l1 := &testLogger{} - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = l1 - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - l1.logs = l1.logs[0:0] // Clear logs written when establishing connection - - batch := &pgx.Batch{} - batch.Queue("create table foo (id bigint)") - batch.Queue("drop table foo") - - br := conn.SendBatch(context.Background(), batch) - - _, err := br.Exec() - if err != nil { - t.Fatalf("Unexpected error creating table: %v", err) - } - - _, err = br.Exec() - if err != nil { - t.Fatalf("Unexpected error dropping table: %v", err) - } - - if len(l1.logs) != 3 { - t.Fatalf("Expected two log entries but got %d", len(l1.logs)) - } - - if l1.logs[0].msg != "SendBatch" { - t.Errorf("Expected first log message to be 'SendBatch' but was '%s'", l1.logs[0].msg) - } - - if l1.logs[1].msg != "BatchResult.Exec" { - t.Errorf("Expected first log message to be 'BatchResult.Exec' but was '%s'", l1.logs[0].msg) - } - - if l1.logs[1].data["sql"] != "create table foo (id bigint)" { - t.Errorf("Expected the first query to be 'create table foo (id bigint)' but was '%s'", l1.logs[0].data["sql"]) - } - - if l1.logs[2].msg != "BatchResult.Exec" { - t.Errorf("Expected second log message to be 'BatchResult.Exec' but was '%s", l1.logs[1].msg) - } - - if l1.logs[2].data["sql"] != "drop table foo" { - t.Errorf("Expected the second query to be 'drop table foo' but was '%s'", l1.logs[1].data["sql"]) - } -} - -func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { - l1 := &testLogger{} - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = l1 - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - l1.logs = l1.logs[0:0] // Clear logs written when establishing connection - - batch := &pgx.Batch{} - batch.Queue("select generate_series(1,$1)", 100) - batch.Queue("select 1 = 1;") - - br := conn.SendBatch(context.Background(), batch) - - if err := br.Close(); err != nil { - t.Fatalf("Unexpected batch error: %v", err) - } - - if len(l1.logs) != 3 { - t.Fatalf("Expected 2 log statements but found %d", len(l1.logs)) - } - - if l1.logs[0].msg != "SendBatch" { - t.Errorf("Expected first log message to be 'SendBatch' but was '%s'", l1.logs[0].msg) - } - - if l1.logs[1].msg != "BatchResult.Close" { - t.Errorf("Expected first log statement to be 'BatchResult.Close' but was '%s'", l1.logs[0].msg) - } - - if l1.logs[1].data["sql"] != "select generate_series(1,$1)" { - t.Errorf("Expected first query to be 'select generate_series(1,$1)' but was '%s'", l1.logs[0].data["sql"]) - } - - if l1.logs[2].msg != "BatchResult.Close" { - t.Errorf("Expected second log statement to be 'BatchResult.Close' but was %s", l1.logs[1].msg) - } - - if l1.logs[2].data["sql"] != "select 1 = 1;" { - t.Errorf("Expected second query to be 'select 1 = 1;' but was '%s'", l1.logs[1].data["sql"]) - } -} - func TestSendBatchSimpleProtocol(t *testing.T) { t.Parallel() config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -847,3 +867,59 @@ func TestSendBatchSimpleProtocol(t *testing.T) { assert.EqualValues(t, 3, values[0]) assert.False(t, rows.Next()) } + +func ExampleConn_SendBatch() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + batch := &pgx.Batch{} + batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) + + batch.Queue("select 1 + 2").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) + + batch.Queue("select 2 + 3").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) + + err = conn.SendBatch(context.Background(), batch).Close() + if err != nil { + fmt.Printf("SendBatch error: %v", err) + return + } + + // Output: + // 2 + // 3 + // 5 +} diff --git a/bench_test.go b/bench_test.go index f2d98bab..73e1b258 100644 --- a/bench_test.go +++ b/bench_test.go @@ -12,16 +12,32 @@ import ( "testing" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" - "github.com/jackc/pgtype" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/internal/nbconn" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" ) +func BenchmarkConnectClose(b *testing.B) { + for i := 0; i < b.N; i++ { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + b.Fatal(err) + } + + err = conn.Close(context.Background()) + if err != nil { + b.Fatal(err) + } + } +} + func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -43,9 +59,9 @@ func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) { func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 32 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -67,9 +83,9 @@ func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B func BenchmarkMinimalUnpreparedSelectWithStatementCacheModePrepare(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -268,123 +284,6 @@ func BenchmarkPointerPointerWithPresentValues(b *testing.B) { } } -func BenchmarkSelectWithoutLogging(b *testing.B) { - conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -type discardLogger struct{} - -func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { -} - -func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelTrace - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelDebug - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelInfo - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelError - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) { - _, err := conn.Prepare(context.Background(), "test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var record struct { - id int32 - userName string - email string - name string - sex string - birthDate time.Time - lastLoginTime time.Time - } - - err = conn.QueryRow(context.Background(), "test").Scan( - &record.id, - &record.userName, - &record.email, - &record.name, - &record.sex, - &record.birthDate, - &record.lastLoginTime, - ) - if err != nil { - b.Fatal(err) - } - - // These checks both ensure that the correct data was returned - // and provide a benchmark of accessing the returned values. - if record.id != 1 { - b.Fatalf("bad value for id: %v", record.id) - } - if record.userName != "johnsmith" { - b.Fatalf("bad value for userName: %v", record.userName) - } - if record.email != "johnsmith@example.com" { - b.Fatalf("bad value for email: %v", record.email) - } - if record.name != "John Smith" { - b.Fatalf("bad value for name: %v", record.name) - } - if record.sex != "male" { - b.Fatalf("bad value for sex: %v", record.sex) - } - if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) { - b.Fatalf("bad value for birthDate: %v", record.birthDate) - } - if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { - b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) - } - } -} - const benchmarkWriteTableCreateSQL = `drop table if exists t; create table t( @@ -437,7 +336,7 @@ const benchmarkWriteTableInsertSQL = `insert into t( type benchmarkWriteTableCopyFromSrc struct { count int idx int - row []interface{} + row []any } func (s *benchmarkWriteTableCopyFromSrc) Next() bool { @@ -445,7 +344,7 @@ func (s *benchmarkWriteTableCopyFromSrc) Next() bool { return s.idx < s.count } -func (s *benchmarkWriteTableCopyFromSrc) Values() ([]interface{}, error) { +func (s *benchmarkWriteTableCopyFromSrc) Values() ([]any, error) { return s.row, nil } @@ -456,15 +355,15 @@ func (s *benchmarkWriteTableCopyFromSrc) Err() error { func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource { return &benchmarkWriteTableCopyFromSrc{ count: count, - row: []interface{}{ + row: []any{ "varchar_1", "varchar_2", - &pgtype.Text{Status: pgtype.Null}, + &pgtype.Text{}, time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), - &pgtype.Date{Status: pgtype.Null}, + &pgtype.Date{}, 1, 2, - &pgtype.Int4{Status: pgtype.Null}, + &pgtype.Int4{}, time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local), true, @@ -508,9 +407,9 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { } } -type queryArgs []interface{} +type queryArgs []any -func (qa *queryArgs) Append(v interface{}) string { +func (qa *queryArgs) Append(v any) string { *qa = append(*qa, v) return "$" + strconv.Itoa(len(*qa)) } @@ -723,7 +622,9 @@ func BenchmarkWrite10000RowsViaCopy(b *testing.B) { func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -733,9 +634,9 @@ func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) { func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -745,9 +646,9 @@ func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) { func BenchmarkMultipleQueriesNonBatchDescribeStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 32 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -783,7 +684,9 @@ func benchmarkMultipleQueriesNonBatch(b *testing.B, conn *pgx.Conn, queryCount i func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -793,9 +696,9 @@ func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) { func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -805,9 +708,9 @@ func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) { func BenchmarkMultipleQueriesBatchDescribeStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 32 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -918,8 +821,7 @@ func BenchmarkSelectManyRegisteredEnum(b *testing.B) { err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid) require.NoError(b, err) - et := pgtype.NewEnumType("color", []string{"blue", "green", "orange"}) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "color", OID: oid}) + conn.TypeMap().RegisterType(&pgtype.Type{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}}) b.ResetTimer() var x, y, z string @@ -1105,73 +1007,6 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) { } } -func BenchmarkSelectRowsExplicitDecoding(b *testing.B) { - conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(b, conn) - - rowCounts := getSelectRowsCounts(b) - - for _, rowCount := range rowCounts { - b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { - br := &BenchRowDecoder{} - for i := 0; i < b.N; i++ { - rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) - if err != nil { - b.Fatal(err) - } - - for rows.Next() { - rawValues := rows.RawValues() - - err = br.ID.DecodeBinary(conn.ConnInfo(), rawValues[0]) - if err != nil { - b.Fatal(err) - } - - err = br.FirstName.DecodeText(conn.ConnInfo(), rawValues[1]) - if err != nil { - b.Fatal(err) - } - - err = br.LastName.DecodeText(conn.ConnInfo(), rawValues[2]) - if err != nil { - b.Fatal(err) - } - - err = br.Sex.DecodeText(conn.ConnInfo(), rawValues[3]) - if err != nil { - b.Fatal(err) - } - - err = br.BirthDate.DecodeBinary(conn.ConnInfo(), rawValues[4]) - if err != nil { - b.Fatal(err) - } - - err = br.Weight.DecodeBinary(conn.ConnInfo(), rawValues[5]) - if err != nil { - b.Fatal(err) - } - - err = br.Height.DecodeBinary(conn.ConnInfo(), rawValues[6]) - if err != nil { - b.Fatal(err) - } - - err = br.UpdateTime.DecodeBinary(conn.ConnInfo(), rawValues[7]) - if err != nil { - b.Fatal(err) - } - } - - if rows.Err() != nil { - b.Fatal(rows.Err()) - } - } - }) - } -} - func BenchmarkSelectRowsPgConnExecText(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) @@ -1285,7 +1120,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) { } type queryRecorder struct { - conn net.Conn + conn nbconn.Conn writeBuf []byte readCount int } @@ -1301,6 +1136,14 @@ func (qr *queryRecorder) Write(b []byte) (n int, err error) { return qr.conn.Write(b) } +func (qr *queryRecorder) BufferReadUntilBlock() error { + return qr.conn.BufferReadUntilBlock() +} + +func (qr *queryRecorder) Flush() error { + return qr.conn.Flush() +} + func (qr *queryRecorder) Close() error { return qr.conn.Close() } diff --git a/ci/setup_test.bash b/ci/setup_test.bash index c279a7a4..8f3f26f0 100755 --- a/ci/setup_test.bash +++ b/ci/setup_test.bash @@ -13,6 +13,10 @@ then echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf @@ -24,8 +28,15 @@ then psql -U postgres -c 'create database pgx_test' psql -U postgres pgx_test -c 'create extension hstore' psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' + psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user `whoami`" + psql -U postgres -c "create user pgx_replication with replication password 'secret'" + + # The tricky test user, below, has to actually exist so that it can be used in a test + # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. + psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" fi if [[ "${PGVERSION-}" =~ ^cockroach ]] diff --git a/conn.go b/conn.go index 854561e0..a95cff21 100644 --- a/conn.go +++ b/conn.go @@ -8,34 +8,36 @@ import ( "strings" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" - "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgtype" - "github.com/jackc/pgx/v4/internal/sanitize" + "github.com/jackc/pgx/v5/internal/anynil" + "github.com/jackc/pgx/v5/internal/sanitize" + "github.com/jackc/pgx/v5/internal/stmtcache" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" ) // ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and // then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic. type ConnConfig struct { pgconn.Config - Logger Logger - LogLevel LogLevel + + Tracer QueryTracer // Original connection string that was parsed into config. connString string - // BuildStatementCache creates the stmtcache.Cache implementation for connections created with this config. Set - // to nil to disable automatic prepared statements. - BuildStatementCache BuildStatementCacheFunc + // StatementCacheCapacity is maximum size of the statement cache used when executing a query with "cache_statement" + // query exec mode. + StatementCacheCapacity int - // PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended - // protocol. This can improve performance due to being able to use the binary format. It also does not rely on client - // side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement) - // and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be - // used by default. The same functionality can be controlled on a per query basis by setting - // QueryExOptions.SimpleProtocol. - PreferSimpleProtocol bool + // DescriptionCacheCapacity is the maximum size of the description cache used when executing a query with + // "cache_describe" query exec mode. + DescriptionCacheCapacity int + + // DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol + // and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as + // PGBouncer. In this case it may be preferrable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same + // functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument. + DefaultQueryExecMode QueryExecMode createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -53,28 +55,29 @@ func (cc *ConnConfig) Copy() *ConnConfig { // ConnString returns the connection string as parsed by pgx.ParseConfig into pgx.ConnConfig. func (cc *ConnConfig) ConnString() string { return cc.connString } -// BuildStatementCacheFunc is a function that can be used to create a stmtcache.Cache implementation for connection. -type BuildStatementCacheFunc func(conn *pgconn.PgConn) stmtcache.Cache - // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access // to multiple database connections from multiple goroutines. type Conn struct { pgConn *pgconn.PgConn config *ConnConfig // config used when establishing this connection preparedStatements map[string]*pgconn.StatementDescription - stmtcache stmtcache.Cache - logger Logger - logLevel LogLevel + statementCache stmtcache.Cache + descriptionCache stmtcache.Cache + + queryTracer QueryTracer + batchTracer BatchTracer + copyFromTracer CopyFromTracer + prepareTracer PrepareTracer notifications []*pgconn.Notification doneChan chan struct{} closedChan chan error - connInfo *pgtype.ConnInfo + typeMap *pgtype.Map wbuf []byte - eqb extendedQueryBuilder + eqb ExtendedQueryBuilder } // Identifier a PostgreSQL identifier or name. Identifiers can be composed of @@ -94,8 +97,8 @@ func (ident Identifier) Sanitize() string { // ErrNoRows occurs when rows are expected but none are returned. var ErrNoRows = errors.New("no rows in result set") -// ErrInvalidLogLevel occurs on attempt to set an invalid log level. -var ErrInvalidLogLevel = errors.New("invalid log level") +var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") +var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") // Connect establishes a connection with a PostgreSQL server with a connection string. See // pgconn.Connect for details. @@ -110,32 +113,34 @@ func Connect(ctx context.Context, connString string) (*Conn, error) { // ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct. // connConfig must have been created by ParseConfig. func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { + // In general this improves safety. In particular avoid the config.Config.OnNotification mutation from affecting other + // connections with the same config. See https://github.com/jackc/pgx/issues/618. + connConfig = connConfig.Copy() + return connect(ctx, connConfig) } // ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig // does. In addition, it accepts the following options: // +// default_query_exec_mode +// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See +// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement". +// // statement_cache_capacity -// The maximum size of the automatic statement cache. Set to 0 to disable automatic statement caching. Default: 512. +// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode. +// Default: 512. // -// statement_cache_mode -// Possible values: "prepare" and "describe". "prepare" will create prepared statements on the PostgreSQL server. -// "describe" will use the anonymous prepared statement to describe a statement without creating a statement on the -// server. "describe" is primarily useful when the environment does not allow prepared statements such as when -// running a connection pooler like PgBouncer. Default: "prepare" -// -// prefer_simple_protocol -// Possible values: "true" and "false". Use the simple protocol instead of extended protocol. Default: false +// description_cache_capacity +// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode. +// Default: 512. func ParseConfig(connString string) (*ConnConfig, error) { config, err := pgconn.ParseConfig(connString) if err != nil { return nil, err } - var buildStatementCache BuildStatementCacheFunc statementCacheCapacity := 512 - statementCacheMode := stmtcache.ModePrepare if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok { delete(config.RuntimeParams, "statement_cache_capacity") n, err := strconv.ParseInt(s, 10, 32) @@ -145,85 +150,85 @@ func ParseConfig(connString string) (*ConnConfig, error) { statementCacheCapacity = int(n) } - if s, ok := config.RuntimeParams["statement_cache_mode"]; ok { - delete(config.RuntimeParams, "statement_cache_mode") + descriptionCacheCapacity := 512 + if s, ok := config.RuntimeParams["description_cache_capacity"]; ok { + delete(config.RuntimeParams, "description_cache_capacity") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err) + } + descriptionCacheCapacity = int(n) + } + + defaultQueryExecMode := QueryExecModeCacheStatement + if s, ok := config.RuntimeParams["default_query_exec_mode"]; ok { + delete(config.RuntimeParams, "default_query_exec_mode") switch s { - case "prepare": - statementCacheMode = stmtcache.ModePrepare - case "describe": - statementCacheMode = stmtcache.ModeDescribe + case "cache_statement": + defaultQueryExecMode = QueryExecModeCacheStatement + case "cache_describe": + defaultQueryExecMode = QueryExecModeCacheDescribe + case "describe_exec": + defaultQueryExecMode = QueryExecModeDescribeExec + case "exec": + defaultQueryExecMode = QueryExecModeExec + case "simple_protocol": + defaultQueryExecMode = QueryExecModeSimpleProtocol default: - return nil, fmt.Errorf("invalid statement_cache_mod: %s", s) - } - } - - if statementCacheCapacity > 0 { - buildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, statementCacheMode, statementCacheCapacity) - } - } - - preferSimpleProtocol := false - if s, ok := config.RuntimeParams["prefer_simple_protocol"]; ok { - delete(config.RuntimeParams, "prefer_simple_protocol") - if b, err := strconv.ParseBool(s); err == nil { - preferSimpleProtocol = b - } else { - return nil, fmt.Errorf("invalid prefer_simple_protocol: %v", err) + return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err) } } connConfig := &ConnConfig{ - Config: *config, - createdByParseConfig: true, - LogLevel: LogLevelInfo, - BuildStatementCache: buildStatementCache, - PreferSimpleProtocol: preferSimpleProtocol, - connString: connString, + Config: *config, + createdByParseConfig: true, + StatementCacheCapacity: statementCacheCapacity, + DescriptionCacheCapacity: descriptionCacheCapacity, + DefaultQueryExecMode: defaultQueryExecMode, + connString: connString, } return connConfig, nil } +// connect connects to a database. connect takes ownership of config. The caller must not use or access it again. func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { + if connectTracer, ok := config.Tracer.(ConnectTracer); ok { + ctx = connectTracer.TraceConnectStart(ctx, TraceConnectStartData{ConnConfig: config}) + defer func() { + connectTracer.TraceConnectEnd(ctx, TraceConnectEndData{Conn: c, Err: err}) + }() + } + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { panic("config must be created by ParseConfig") } - originalConfig := config - - // This isn't really a deep copy. But it is enough to avoid the config.Config.OnNotification mutation from affecting - // other connections with the same config. See https://github.com/jackc/pgx/issues/618. - { - configCopy := *config - config = &configCopy - } c = &Conn{ - config: originalConfig, - connInfo: pgtype.NewConnInfo(), - logLevel: config.LogLevel, - logger: config.Logger, + config: config, + typeMap: pgtype.NewMap(), + queryTracer: config.Tracer, + } + + if t, ok := c.queryTracer.(BatchTracer); ok { + c.batchTracer = t + } + if t, ok := c.queryTracer.(CopyFromTracer); ok { + c.copyFromTracer = t + } + if t, ok := c.queryTracer.(PrepareTracer); ok { + c.prepareTracer = t } // Only install pgx notification system if no other callback handler is present. if config.Config.OnNotification == nil { config.Config.OnNotification = c.bufferNotifications - } else { - if c.shouldLog(LogLevelDebug) { - c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]interface{}{"host": config.Config.Host}) - } } - if c.shouldLog(LogLevelInfo) { - c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host}) - } c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) if err != nil { - if c.shouldLog(LogLevelError) { - c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err}) - } return nil, err } @@ -232,14 +237,12 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) - if c.config.BuildStatementCache != nil { - c.stmtcache = c.config.BuildStatementCache(c.pgConn) + if c.config.StatementCacheCapacity > 0 { + c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity) } - // Replication connections can't execute the queries to - // populate the c.PgTypes and c.pgsqlAfInet - if _, ok := config.Config.RuntimeParams["replication"]; ok { - return c, nil + if c.config.DescriptionCacheCapacity > 0 { + c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity) } return c, nil @@ -253,9 +256,6 @@ func (c *Conn) Close(ctx context.Context) error { } err := c.pgConn.Close(ctx) - if c.shouldLog(LogLevelInfo) { - c.log(ctx, LogLevelInfo, "closed connection", nil) - } return err } @@ -266,18 +266,23 @@ func (c *Conn) Close(ctx context.Context) error { // name and sql arguments. This allows a code path to Prepare and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { + if c.prepareTracer != nil { + ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql}) + } + if name != "" { var ok bool if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql { + if c.prepareTracer != nil { + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true}) + } return sd, nil } } - if c.shouldLog(LogLevelError) { + if c.prepareTracer != nil { defer func() { - if err != nil { - c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) - } + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err}) }() } @@ -339,21 +344,6 @@ func (c *Conn) die(err error) { c.pgConn.Close(ctx) } -func (c *Conn) shouldLog(lvl LogLevel) bool { - return c.logger != nil && c.logLevel >= lvl -} - -func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) { - if data == nil { - data = map[string]interface{}{} - } - if c.pgConn != nil && c.pgConn.PID() != 0 { - data["pid"] = c.pgConn.PID() - } - - c.logger.Log(ctx, lvl, msg, data) -} - func quoteIdentifier(s string) string { return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } @@ -372,87 +362,111 @@ func (c *Conn) Ping(ctx context.Context) error { // is used and the connection must be returned to the same state before any *pgx.Conn methods are again used. func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn } -// StatementCache returns the statement cache used for this connection. -func (c *Conn) StatementCache() stmtcache.Cache { return c.stmtcache } - -// ConnInfo returns the connection info used for this connection. -func (c *Conn) ConnInfo() *pgtype.ConnInfo { return c.connInfo } +// TypeMap returns the connection info used for this connection. +func (c *Conn) TypeMap() *pgtype.Map { return c.typeMap } // Config returns a copy of config that was used to establish this connection. func (c *Conn) Config() *ConnConfig { return c.config.Copy() } // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // positionally from the sql string as $1, $2, etc. -func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { - startTime := time.Now() - - commandTag, err := c.exec(ctx, sql, arguments...) - if err != nil { - if c.shouldLog(LogLevelError) { - endTime := time.Now() - c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err, "time": endTime.Sub(startTime)}) - } - return commandTag, err +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + if c.queryTracer != nil { + ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: arguments}) } - if c.shouldLog(LogLevelInfo) { - endTime := time.Now() - c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return pgconn.CommandTag{}, err + } + + commandTag, err := c.exec(ctx, sql, arguments...) + + if c.queryTracer != nil { + c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{CommandTag: commandTag, Err: err}) } return commandTag, err } -func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { - simpleProtocol := c.config.PreferSimpleProtocol +func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { + mode := c.config.DefaultQueryExecMode + var queryRewriter QueryRewriter optionLoop: for len(arguments) > 0 { switch arg := arguments[0].(type) { - case QuerySimpleProtocol: - simpleProtocol = bool(arg) + case QueryExecMode: + mode = arg + arguments = arguments[1:] + case QueryRewriter: + queryRewriter = arg arguments = arguments[1:] default: break optionLoop } } + if queryRewriter != nil { + sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + } + + // Always use simple protocol when there are no arguments. + if len(arguments) == 0 { + mode = QueryExecModeSimpleProtocol + } + if sd, ok := c.preparedStatements[sql]; ok { return c.execPrepared(ctx, sd, arguments) } - if simpleProtocol { - return c.execSimpleProtocol(ctx, sql, arguments) - } - - if len(arguments) == 0 { - return c.execSimpleProtocol(ctx, sql, arguments) - } - - if c.stmtcache != nil { - sd, err := c.stmtcache.Get(ctx, sql) - if err != nil { - return nil, err + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + return pgconn.CommandTag{}, errDisabledStatementCache + } + sd := c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) + if err != nil { + return pgconn.CommandTag{}, err + } + c.statementCache.Put(sd) } - if c.stmtcache.Mode() == stmtcache.ModeDescribe { - return c.execParams(ctx, sd, arguments) + return c.execPrepared(ctx, sd, arguments) + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + return pgconn.CommandTag{}, errDisabledDescriptionCache + } + sd := c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + return pgconn.CommandTag{}, err + } + } + + return c.execParams(ctx, sd, arguments) + case QueryExecModeDescribeExec: + sd, err := c.Prepare(ctx, "", sql) + if err != nil { + return pgconn.CommandTag{}, err } return c.execPrepared(ctx, sd, arguments) + case QueryExecModeExec: + return c.execSQLParams(ctx, sql, arguments) + case QueryExecModeSimpleProtocol: + return c.execSimpleProtocol(ctx, sql, arguments) + default: + return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode) } - - sd, err := c.Prepare(ctx, "", sql) - if err != nil { - return nil, err - } - return c.execPrepared(ctx, sd, arguments) } -func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) { +func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []any) (commandTag pgconn.CommandTag, err error) { if len(arguments) > 0 { sql, err = c.sanitizeForSimpleQuery(sql, arguments...) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } } @@ -464,60 +478,53 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i return commandTag, err } -func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, arguments []interface{}) error { - if len(sd.ParamOIDs) != len(arguments) { - return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(arguments)) - } - - c.eqb.Reset() - - args, err := convertDriverValuers(arguments) +func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { + err := c.eqb.Build(c.typeMap, sd, arguments) if err != nil { - return err + return pgconn.CommandTag{}, err } - for i := range args { - err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) - if err != nil { - return err - } - } - - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) - } - - return nil -} - -func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { - err := c.execParamsAndPreparedPrefix(sd, arguments) - if err != nil { - return nil, err - } - - result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } -func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { - err := c.execParamsAndPreparedPrefix(sd, arguments) +func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { + err := c.eqb.Build(c.typeMap, sd, arguments) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } - result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } -func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows { - r := &connRows{} +type unknownArgumentTypeQueryExecModeExecError struct { + arg any +} + +func (e *unknownArgumentTypeQueryExecModeExecError) Error() string { + return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg) +} + +func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) { + err := c.eqb.Build(c.typeMap, nil, args) + if err != nil { + return pgconn.CommandTag{}, err + } + + result := c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + return result.CommandTag, result.Err +} + +func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows { + r := &baseRows{} r.ctx = ctx - r.logger = c - r.connInfo = c.connInfo + r.queryTracer = c.queryTracer + r.typeMap = c.typeMap r.startTime = time.Now() r.sql = sql r.args = args @@ -526,8 +533,64 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *con return r } -// QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query. -type QuerySimpleProtocol bool +type QueryExecMode int32 + +const ( + _ QueryExecMode = iota + + // Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single + // round trip after the statement is cached. This is the default. + QueryExecModeCacheStatement + + // Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the + // extended protocol. Queries are executed in a single round trip after the description is cached. If the database + // schema is modified or the search_path is changed this may result in undetected result decoding errors. + QueryExecModeCacheDescribe + + // Get the statement description on every execution. This uses the extended protocol. Queries require two round trips + // to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even + // when the the database schema is modified concurrently. + QueryExecModeDescribeExec + + // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol + // with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be + // registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are + // unregistered or ambigious. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know + // the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot. + QueryExecModeExec + + // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. + // Queries are executed in a single round trip. Type mappings can be registered with + // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. + // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use + // a map[string]string directly as an argument. This mode cannot. + // + // QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor + // exceptions such as behavior when multiple result returning queries are erroneously sent in a single string. + // + // QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer + // QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol + // should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does + // not support the extended protocol. + QueryExecModeSimpleProtocol +) + +func (m QueryExecMode) String() string { + switch m { + case QueryExecModeCacheStatement: + return "cache statement" + case QueryExecModeCacheDescribe: + return "cache describe" + case QueryExecModeDescribeExec: + return "describe exec" + case QueryExecModeExec: + return "exec" + case QueryExecModeSimpleProtocol: + return "simple protocol" + default: + return "invalid" + } +} // QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position. type QueryResultFormats []int16 @@ -535,20 +598,45 @@ type QueryResultFormats []int16 // QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID. type QueryResultFormatsByOID map[uint32]int16 -// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The -// error will be the available in rows.Err() after rows are closed. So it is allowed to ignore the error returned from -// Query and handle it in Rows. +// QueryRewriter rewrites a query when used as the first arguments to a query method. +type QueryRewriter interface { + RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) +} + +// Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query +// and initializing Rows will be returned. Err() on the returned Rows must be checked after the Rows is closed to +// determine if the query executed successfully. // -// Err() on the returned Rows must be checked after the Rows is closed to determine if the query executed successfully -// as some errors can only be detected by reading the entire response. e.g. A divide by zero error on the last row. +// The returned Rows must be closed before the connection can be used again. It is safe to attempt to read from the +// returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It +// is allowed to ignore the error returned from Query and handle it in Rows. // -// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and +// It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be +// collected before processing rather than processed while receiving each row. This avoids the possibility of the +// application processing rows from a query that the server rejected. The CollectRows function is useful here. +// +// An implementor of QueryRewriter may be passed as the first element of args. It can rewrite the sql and change or +// replace args. For example, NamedArgs is QueryRewriter that implements named arguments. +// +// For extra control over how the query is executed, the types QueryExecMode, QueryResultFormats, and // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. -func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { +func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) { + if c.queryTracer != nil { + ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: args}) + } + + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + if c.queryTracer != nil { + c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err}) + } + return &baseRows{err: err, closed: true}, err + } + var resultFormats QueryResultFormats var resultFormatsByOID QueryResultFormatsByOID - simpleProtocol := c.config.PreferSimpleProtocol + mode := c.config.DefaultQueryExecMode + var queryRewriter QueryRewriter optionLoop: for len(args) > 0 { @@ -559,20 +647,112 @@ optionLoop: case QueryResultFormatsByOID: resultFormatsByOID = arg args = args[1:] - case QuerySimpleProtocol: - simpleProtocol = bool(arg) + case QueryExecMode: + mode = arg + args = args[1:] + case QueryRewriter: + queryRewriter = arg args = args[1:] default: break optionLoop } } + if queryRewriter != nil { + sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args) + } + + // Bypass any statement caching. + if sql == "" { + mode = QueryExecModeSimpleProtocol + } + + c.eqb.reset() + anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) var err error - sd, ok := c.preparedStatements[sql] + sd, explicitPreparedStatement := c.preparedStatements[sql] + if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec { + if sd == nil { + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + err = errDisabledStatementCache + rows.fatal(err) + return rows, err + } + sd = c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) + if err != nil { + rows.fatal(err) + return rows, err + } + c.statementCache.Put(sd) + } + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + err = errDisabledDescriptionCache + rows.fatal(err) + return rows, err + } + sd = c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + rows.fatal(err) + return rows, err + } + c.descriptionCache.Put(sd) + } + case QueryExecModeDescribeExec: + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + rows.fatal(err) + return rows, err + } + } + } - if simpleProtocol && !ok { + if len(sd.ParamOIDs) != len(args) { + rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))) + return rows, rows.err + } + + rows.sql = sd.SQL + + err = c.eqb.Build(c.typeMap, sd, args) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + + if resultFormatsByOID != nil { + resultFormats = make([]int16, len(sd.Fields)) + for i := range resultFormats { + resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)] + } + } + + if resultFormats == nil { + resultFormats = c.eqb.ResultFormats + } + + if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe { + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats) + } else { + rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) + } + } else if mode == QueryExecModeExec { + err := c.eqb.Build(c.typeMap, nil, args) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else if mode == QueryExecModeSimpleProtocol { sql, err = c.sanitizeForSimpleQuery(sql, args...) if err != nil { rows.fatal(err) @@ -590,68 +770,13 @@ optionLoop: } return rows, nil - } - - c.eqb.Reset() - - if !ok { - if c.stmtcache != nil { - sd, err = c.stmtcache.Get(ctx, sql) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - } else { - sd, err = c.pgConn.Prepare(ctx, "", sql, nil) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - } - } - if len(sd.ParamOIDs) != len(args) { - rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))) - return rows, rows.err - } - - rows.sql = sd.SQL - - args, err = convertDriverValuers(args) - if err != nil { + } else { + err = fmt.Errorf("unknown QueryExecMode: %v", mode) rows.fatal(err) return rows, rows.err } - for i := range args { - err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - } - - if resultFormatsByOID != nil { - resultFormats = make([]int16, len(sd.Fields)) - for i := range resultFormats { - resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)] - } - } - - if resultFormats == nil { - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) - } - - resultFormats = c.eqb.resultFormats - } - - if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe && !ok { - rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats) - } else { - rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) - } - - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return rows, rows.err } @@ -659,179 +784,302 @@ optionLoop: // QueryRow is a convenience wrapper over Query. Any error that occurs while // querying is deferred until calling Scan on the returned Row. That Row will // error with ErrNoRows if no rows are returned. -func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := c.Query(ctx, sql, args...) - return (*connRow)(rows.(*connRows)) -} - -// QueryFuncRow is the argument to the QueryFunc callback function. -// -// QueryFuncRow is an interface instead of a struct to allow tests to mock QueryFunc. However, adding a method to an -// interface is technically a breaking change. Because of this the QueryFuncRow interface is partially excluded from -// semantic version requirements. Methods will not be removed or changed, but new methods may be added. -type QueryFuncRow interface { - FieldDescriptions() []pgproto3.FieldDescription - - // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid during the current - // function call. However, the underlying byte data is safe to retain a reference to and mutate. - RawValues() [][]byte -} - -// QueryFunc executes sql with args. For each row returned by the query the values will scanned into the elements of -// scans and f will be called. If any row fails to scan or f returns an error the query will be aborted and the error -// will be returned. -func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { - rows, err := c.Query(ctx, sql, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - err = rows.Scan(scans...) - if err != nil { - return nil, err - } - - err = f(rows) - if err != nil { - return nil, err - } - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return rows.CommandTag(), nil + return (*connRow)(rows.(*baseRows)) } // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. -func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { - startTime := time.Now() +func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { + if c.batchTracer != nil { + ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b}) + defer func() { + err := br.(interface{ earlyError() error }).earlyError() + if err != nil { + c.batchTracer.TraceBatchEnd(ctx, c, TraceBatchEndData{Err: err}) + } + }() + } - simpleProtocol := c.config.PreferSimpleProtocol + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + + mode := c.config.DefaultQueryExecMode + + for _, bi := range b.queuedQueries { + var queryRewriter QueryRewriter + sql := bi.query + arguments := bi.arguments + + optionLoop: + for len(arguments) > 0 { + switch arg := arguments[0].(type) { + case QueryRewriter: + queryRewriter = arg + arguments = arguments[1:] + default: + break optionLoop + } + } + + if queryRewriter != nil { + sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + } + + bi.query = sql + bi.arguments = arguments + } + + if mode == QueryExecModeSimpleProtocol { + return c.sendBatchQueryExecModeSimpleProtocol(ctx, b) + } + + // All other modes use extended protocol and thus can use prepared statements. + for _, bi := range b.queuedQueries { + if sd, ok := c.preparedStatements[bi.query]; ok { + bi.sd = sd + } + } + + switch mode { + case QueryExecModeExec: + return c.sendBatchQueryExecModeExec(ctx, b) + case QueryExecModeCacheStatement: + return c.sendBatchQueryExecModeCacheStatement(ctx, b) + case QueryExecModeCacheDescribe: + return c.sendBatchQueryExecModeCacheDescribe(ctx, b) + case QueryExecModeDescribeExec: + return c.sendBatchQueryExecModeDescribeExec(ctx, b) + default: + panic("unknown QueryExecMode") + } +} + +func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults { var sb strings.Builder - if simpleProtocol { - for i, bi := range b.items { - if i > 0 { - sb.WriteByte(';') - } - sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - sb.WriteString(sql) + for i, bi := range b.queuedQueries { + if i > 0 { + sb.WriteByte(';') } - mrr := c.pgConn.Exec(ctx, sb.String()) - return &batchResults{ - ctx: ctx, - conn: c, - mrr: mrr, - b: b, - ix: 0, + sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} } + sb.WriteString(sql) } - - distinctUnpreparedQueries := map[string]struct{}{} - - for _, bi := range b.items { - if _, ok := c.preparedStatements[bi.query]; ok { - continue - } - distinctUnpreparedQueries[bi.query] = struct{}{} - } - - var stmtCache stmtcache.Cache - if len(distinctUnpreparedQueries) > 0 { - if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) { - stmtCache = c.stmtcache - } else { - stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) - } - - for sql, _ := range distinctUnpreparedQueries { - _, err := stmtCache.Get(ctx, sql) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - } + mrr := c.pgConn.Exec(ctx, sb.String()) + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + qqIdx: 0, } +} +func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults { batch := &pgconn.Batch{} - for _, bi := range b.items { - c.eqb.Reset() - - sd := c.preparedStatements[bi.query] - if sd == nil { - var err error - sd, err = stmtCache.Get(ctx, bi.query) + for _, bi := range b.queuedQueries { + sd := bi.sd + if sd != nil { + err := c.eqb.Build(c.typeMap, sd, bi.arguments) if err != nil { - return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err}) + return &batchResults{ctx: ctx, conn: c, err: err} } - } - if len(sd.ParamOIDs) != len(bi.arguments) { - return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}) - } - - args, err := convertDriverValuers(bi.arguments) - if err != nil { - return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err}) - } - - for i := range args { - err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) - if err != nil { - return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err}) - } - } - - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) - } - - if sd.Name == "" { - batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats) + batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) } else { - batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) + err := c.eqb.Build(c.typeMap, nil, bi.arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) } } - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. mrr := c.pgConn.ExecBatch(ctx, batch) - return c.logBatchResults(ctx, startTime, &batchResults{ - ctx: ctx, - conn: c, - mrr: mrr, - b: b, - ix: 0, - }) + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + qqIdx: 0, + } } -func (c *Conn) logBatchResults(ctx context.Context, startTime time.Time, results *batchResults) BatchResults { - if results.err != nil { - if c.shouldLog(LogLevelError) { - endTime := time.Now() - c.log(ctx, LogLevelError, "SendBatch", map[string]interface{}{"err": results.err, "time": endTime.Sub(startTime)}) +func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + if c.statementCache == nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache} + } + + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.queuedQueries { + if bi.sd == nil { + sd := c.statementCache.Get(bi.query) + if sd != nil { + bi.sd = sd + } else { + if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd = &pgconn.StatementDescription{ + Name: stmtcache.NextStatementName(), + SQL: bi.query, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } } - return results } - if c.shouldLog(LogLevelInfo) { - endTime := time.Now() - c.log(ctx, LogLevelInfo, "SendBatch", map[string]interface{}{"batchLen": results.b.Len(), "time": endTime.Sub(startTime)}) - } - - return results + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache) } -func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) { +func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + if c.descriptionCache == nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache} + } + + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.queuedQueries { + if bi.sd == nil { + sd := c.descriptionCache.Get(bi.query) + if sd != nil { + bi.sd = sd + } else { + if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd = &pgconn.StatementDescription{ + SQL: bi.query, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } + } + + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache) +} + +func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.queuedQueries { + if bi.sd == nil { + if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd := &pgconn.StatementDescription{ + SQL: bi.query, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } + + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil) +} + +func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) { + pipeline := c.pgConn.StartPipeline(context.Background()) + defer func() { + if pbr.err != nil { + pipeline.Close() + } + }() + + // Prepare any needed queries + if len(distinctNewQueries) > 0 { + for _, sd := range distinctNewQueries { + pipeline.SendPrepare(sd.Name, sd.SQL, nil) + } + + err := pipeline.Sync() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + for _, sd := range distinctNewQueries { + results, err := pipeline.GetResults() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + resultSD, ok := results.(*pgconn.StatementDescription) + if !ok { + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)} + } + + // Fill in the previously empty / pending statement descriptions. + sd.ParamOIDs = resultSD.ParamOIDs + sd.Fields = resultSD.Fields + } + + results, err := pipeline.GetResults() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + _, ok := results.(*pgconn.PipelineSync) + if !ok { + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)} + } + } + + // Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later. + if sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Put(sd) + } + } + + // Queue the queries. + for _, bi := range b.queuedQueries { + err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments) + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + if bi.sd.Name == "" { + pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else { + pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + } + } + + err := pipeline.Sync() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + return &pipelineBatchResults{ + ctx: ctx, + conn: c, + pipeline: pipeline, + b: b, + } +} + +func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) { if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") } @@ -841,9 +1089,9 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, } var err error - valueArgs := make([]interface{}, len(args)) + valueArgs := make([]any, len(args)) for i, a := range args { - valueArgs[i], err = convertSimpleArgument(c.connInfo, a) + valueArgs[i], err = convertSimpleArgument(c.typeMap, a) if err != nil { return "", err } @@ -851,3 +1099,127 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, return sanitize.SanitizeSQL(sql, valueArgs...) } + +// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. +func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) { + var oid uint32 + + err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) + if err != nil { + return nil, err + } + + var typtype string + + err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) + if err != nil { + return nil, err + } + + switch typtype { + case "b": // array + elementOID, err := c.getArrayElementOID(ctx, oid) + if err != nil { + return nil, err + } + + dt, ok := c.TypeMap().TypeForOID(elementOID) + if !ok { + return nil, errors.New("array element OID not registered") + } + + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}, nil + case "c": // composite + fields, err := c.getCompositeFields(ctx, oid) + if err != nil { + return nil, err + } + + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil + case "e": // enum + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil + default: + return &pgtype.Type{}, errors.New("unknown typtype") + } +} + +func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, error) { + var typelem uint32 + + err := c.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err + } + + return typelem, nil +} + +func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) { + var typrelid uint32 + + err := c.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) + if err != nil { + return nil, err + } + + var fields []pgtype.CompositeCodecField + var fieldName string + var fieldOID uint32 + rows, _ := c.Query(ctx, `select attname, atttypid +from pg_attribute +where attrelid=$1 +order by attnum`, + typrelid, + ) + _, err = ForEachRow(rows, []any{&fieldName, &fieldOID}, func() error { + dt, ok := c.TypeMap().TypeForOID(fieldOID) + if !ok { + return fmt.Errorf("unknown composite type field OID: %v", fieldOID) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) + return nil + }) + if err != nil { + return nil, err + } + + return fields, nil +} + +func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error { + if c.pgConn.TxStatus() != 'I' { + return nil + } + + if c.descriptionCache != nil { + c.descriptionCache.HandleInvalidated() + } + + var invalidatedStatements []*pgconn.StatementDescription + if c.statementCache != nil { + invalidatedStatements = c.statementCache.HandleInvalidated() + } + + if len(invalidatedStatements) == 0 { + return nil + } + + pipeline := c.pgConn.StartPipeline(ctx) + defer pipeline.Close() + + for _, sd := range invalidatedStatements { + pipeline.SendDeallocate(sd.Name) + } + + err := pipeline.Sync() + if err != nil { + return fmt.Errorf("failed to deallocate cached statement(s): %w", err) + } + + err = pipeline.Close() + if err != nil { + return fmt.Errorf("failed to deallocate cached statement(s): %w", err) + } + + return nil +} diff --git a/conn_test.go b/conn_test.go index 467f6ecc..b84093f4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,17 +3,16 @@ package pgx_test import ( "bytes" "context" - "log" "os" "strings" "sync" "testing" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" - "github.com/jackc/pgtype" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -83,7 +82,7 @@ func TestConnectWithPreferSimpleProtocol(t *testing.T) { t.Parallel() connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - connConfig.PreferSimpleProtocol = true + connConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol conn := mustConnect(t, connConfig) defer closeConn(t, conn) @@ -93,13 +92,8 @@ func TestConnectWithPreferSimpleProtocol(t *testing.T) { var s pgtype.Text err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s) - if err != nil { - t.Fatal(err) - } - - if s.Get() != "42" { - t.Fatalf(`expected "42", got %v`, s) - } + require.NoError(t, err) + require.Equal(t, pgtype.Text{String: "42", Valid: true}, s) ensureConnValid(t, conn) } @@ -144,91 +138,122 @@ func TestParseConfigExtractsStatementCacheOptions(t *testing.T) { config, err := pgx.ParseConfig("statement_cache_capacity=0") require.NoError(t, err) - require.Nil(t, config.BuildStatementCache) + require.EqualValues(t, 0, config.StatementCacheCapacity) config, err = pgx.ParseConfig("statement_cache_capacity=42") require.NoError(t, err) - require.NotNil(t, config.BuildStatementCache) - c := config.BuildStatementCache(nil) - require.NotNil(t, c) - require.Equal(t, 42, c.Cap()) - require.Equal(t, stmtcache.ModePrepare, c.Mode()) + require.EqualValues(t, 42, config.StatementCacheCapacity) - config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=prepare") + config, err = pgx.ParseConfig("description_cache_capacity=0") require.NoError(t, err) - require.NotNil(t, config.BuildStatementCache) - c = config.BuildStatementCache(nil) - require.NotNil(t, c) - require.Equal(t, 42, c.Cap()) - require.Equal(t, stmtcache.ModePrepare, c.Mode()) + require.EqualValues(t, 0, config.DescriptionCacheCapacity) - config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=describe") + config, err = pgx.ParseConfig("description_cache_capacity=42") require.NoError(t, err) - require.NotNil(t, config.BuildStatementCache) - c = config.BuildStatementCache(nil) - require.NotNil(t, c) - require.Equal(t, 42, c.Cap()) - require.Equal(t, stmtcache.ModeDescribe, c.Mode()) + require.EqualValues(t, 42, config.DescriptionCacheCapacity) + + // default_query_exec_mode + // Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See + + config, err = pgx.ParseConfig("default_query_exec_mode=cache_statement") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeCacheStatement, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=cache_describe") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeCacheDescribe, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=describe_exec") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeDescribeExec, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=exec") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeExec, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=simple_protocol") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeSimpleProtocol, config.DefaultQueryExecMode) } -func TestParseConfigExtractsPreferSimpleProtocol(t *testing.T) { +func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) { t.Parallel() for _, tt := range []struct { connString string - preferSimpleProtocol bool + defaultQueryExecMode pgx.QueryExecMode }{ - {"", false}, - {"prefer_simple_protocol=false", false}, - {"prefer_simple_protocol=0", false}, - {"prefer_simple_protocol=true", true}, - {"prefer_simple_protocol=1", true}, + {"", pgx.QueryExecModeCacheStatement}, + {"default_query_exec_mode=cache_statement", pgx.QueryExecModeCacheStatement}, + {"default_query_exec_mode=cache_describe", pgx.QueryExecModeCacheDescribe}, + {"default_query_exec_mode=describe_exec", pgx.QueryExecModeDescribeExec}, + {"default_query_exec_mode=exec", pgx.QueryExecModeExec}, + {"default_query_exec_mode=simple_protocol", pgx.QueryExecModeSimpleProtocol}, } { config, err := pgx.ParseConfig(tt.connString) require.NoError(t, err) - require.Equalf(t, tt.preferSimpleProtocol, config.PreferSimpleProtocol, "connString: `%s`", tt.connString) - require.Empty(t, config.RuntimeParams["prefer_simple_protocol"]) + require.Equalf(t, tt.defaultQueryExecMode, config.DefaultQueryExecMode, "connString: `%s`", tt.connString) + require.Empty(t, config.RuntimeParams["default_query_exec_mode"]) } } func TestExec(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } // Accept parameters - if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" { + if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results.String() != "INSERT 0 1" { t.Errorf("Unexpected results from Exec: %v", results) } - if results := mustExec(t, conn, "drop table foo;"); string(results) != "DROP TABLE" { + if results := mustExec(t, conn, "drop table foo;"); results.String() != "DROP TABLE" { t.Error("Unexpected results from Exec") } // Multiple statements can be executed -- last command tag is returned - if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" { + if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results.String() != "DROP TABLE" { t.Error("Unexpected results from Exec") } // Can execute longer SQL strings than sharedBufferSize - if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" { + if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results.String() != "SELECT 1" { t.Errorf("Unexpected results from Exec: %v", results) } // Exec no-op which does not return a command tag - if results := mustExec(t, conn, "--;"); string(results) != "" { + if results := mustExec(t, conn, "--;"); results.String() != "" { t.Errorf("Unexpected results from Exec: %v", results) } }) } +type testQueryRewriter struct { + sql string + args []any +} + +func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) { + return qr.sql, qr.args +} + +func TestExecWithQueryRewriter(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + qr := testQueryRewriter{sql: "select $1::int", args: []any{42}} + _, err := conn.Exec(ctx, "should be replaced", &qr) + require.NoError(t, err) + }) +} + func TestExecFailure(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { if _, err := conn.Exec(context.Background(), "selct;"); err == nil { t.Fatal("Expected SQL syntax error") } @@ -244,7 +269,7 @@ func TestExecFailure(t *testing.T) { func TestExecFailureWithArguments(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { _, err := conn.Exec(context.Background(), "selct $1;", 1) if err == nil { t.Fatal("Expected SQL syntax error") @@ -259,7 +284,7 @@ func TestExecFailureWithArguments(t *testing.T) { func TestExecContextWithoutCancelation(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -267,7 +292,7 @@ func TestExecContextWithoutCancelation(t *testing.T) { if err != nil { t.Fatal(err) } - if string(commandTag) != "CREATE TABLE" { + if commandTag.String() != "CREATE TABLE" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } assert.False(t, pgconn.SafeToRetry(err)) @@ -277,7 +302,7 @@ func TestExecContextWithoutCancelation(t *testing.T) { func TestExecContextFailureWithoutCancelation(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -299,7 +324,7 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) { func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -322,56 +347,6 @@ func TestExecFailureCloseBefore(t *testing.T) { assert.True(t, pgconn.SafeToRetry(err)) } -func TestExecStatementCacheModes(t *testing.T) { - t.Parallel() - - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - - tests := []struct { - name string - buildStatementCache pgx.BuildStatementCacheFunc - }{ - { - name: "disabled", - buildStatementCache: nil, - }, - { - name: "prepare", - buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - }, - }, - { - name: "describe", - buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - }, - }, - } - - for _, tt := range tests { - func() { - config.BuildStatementCache = tt.buildStatementCache - conn := mustConnect(t, config) - defer closeConn(t, conn) - - commandTag, err := conn.Exec(context.Background(), "select 1") - assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 1", string(commandTag), tt.name) - - commandTag, err = conn.Exec(context.Background(), "select 1 union all select 1") - assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 2", string(commandTag), tt.name) - - commandTag, err = conn.Exec(context.Background(), "select 1") - assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 1", string(commandTag), tt.name) - - ensureConnValid(t, conn) - }() - } -} - func TestExecPerQuerySimpleProtocol(t *testing.T) { t.Parallel() @@ -385,19 +360,19 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { if err != nil { t.Fatal(err) } - if string(commandTag) != "CREATE TABLE" { + if commandTag.String() != "CREATE TABLE" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } commandTag, err = conn.Exec(ctx, "insert into foo(name) values($1);", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, "bar'; drop table foo;--", ) if err != nil { t.Fatal(err) } - if string(commandTag) != "INSERT 0 1" { + if commandTag.String() != "INSERT 0 1" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } @@ -501,45 +476,15 @@ func TestPrepareIdempotency(t *testing.T) { func TestPrepareStatementCacheModes(t *testing.T) { t.Parallel() - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Prepare(context.Background(), "test", "select $1::text") + require.NoError(t, err) - tests := []struct { - name string - buildStatementCache pgx.BuildStatementCacheFunc - }{ - { - name: "disabled", - buildStatementCache: nil, - }, - { - name: "prepare", - buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - }, - }, - { - name: "describe", - buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - config.BuildStatementCache = tt.buildStatementCache - conn := mustConnect(t, config) - defer closeConn(t, conn) - - _, err := conn.Prepare(context.Background(), "test", "select $1::text") - require.NoError(t, err) - - var s string - err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s) - require.NoError(t, err) - require.Equal(t, "hello", s) - }) - } + var s string + err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s) + require.NoError(t, err) + require.Equal(t, "hello", s) + }) } func TestListenNotify(t *testing.T) { @@ -595,7 +540,7 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) { func() { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") }() listenerDone := make(chan bool) @@ -671,7 +616,7 @@ func TestListenNotifySelfNotification(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") mustExec(t, conn, "listen self") @@ -706,7 +651,7 @@ func TestFatalRxError(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") var wg sync.WaitGroup wg.Add(1) @@ -745,7 +690,7 @@ func TestFatalTxError(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer otherConn.Close(context.Background()) @@ -770,13 +715,13 @@ func TestFatalTxError(t *testing.T) { func TestInsertBoolArray(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } // Accept parameters - if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" { + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results.String() != "INSERT 0 1" { t.Errorf("Unexpected results from Exec: %v", results) } }) @@ -785,91 +730,18 @@ func TestInsertBoolArray(t *testing.T) { func TestInsertTimestampArray(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } // Accept parameters - if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" { + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results.String() != "INSERT 0 1" { t.Errorf("Unexpected results from Exec: %v", results) } }) } -type testLog struct { - lvl pgx.LogLevel - msg string - data map[string]interface{} -} - -type testLogger struct { - logs []testLog -} - -func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - data["ctxdata"] = ctx.Value("ctxdata") - l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) -} - -func TestLogPassesContext(t *testing.T) { - t.Parallel() - - l1 := &testLogger{} - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = l1 - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - l1.logs = l1.logs[0:0] // Clear logs written when establishing connection - - ctx := context.WithValue(context.Background(), "ctxdata", "foo") - - if _, err := conn.Exec(ctx, ";"); err != nil { - t.Fatal(err) - } - - if len(l1.logs) != 1 { - t.Fatal("Expected logger to be called once, but it wasn't") - } - - if l1.logs[0].data["ctxdata"] != "foo" { - t.Fatal("Expected context data to be passed to logger, but it wasn't") - } -} - -func TestLoggerFunc(t *testing.T) { - t.Parallel() - - const testMsg = "foo" - - buf := bytes.Buffer{} - logger := log.New(&buf, "", 0) - - createAdapterFn := func(logger *log.Logger) pgx.LoggerFunc { - return func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - logger.Printf("%s", testMsg) - } - } - - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = createAdapterFn(logger) - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - buf.Reset() // Clear logs written when establishing connection - - if _, err := conn.Exec(context.TODO(), ";"); err != nil { - t.Fatal(err) - } - - if strings.TrimSpace(buf.String()) != testMsg { - t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String()) - } -} - func TestIdentifierSanitize(t *testing.T) { t.Parallel() @@ -911,7 +783,7 @@ func TestIdentifierSanitize(t *testing.T) { } } -func TestConnInitConnInfo(t *testing.T) { +func TestConnInitTypeMap(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) @@ -923,11 +795,11 @@ func TestConnInitConnInfo(t *testing.T) { "text": pgtype.TextOID, } for name, oid := range nameOIDs { - dtByName, ok := conn.ConnInfo().DataTypeForName(name) + dtByName, ok := conn.TypeMap().TypeForName(name) if !ok { t.Fatalf("Expected type named %v to be present", name) } - dtByOID, ok := conn.ConnInfo().DataTypeForOID(oid) + dtByOID, ok := conn.TypeMap().TypeForOID(oid) if !ok { t.Fatalf("Expected type OID %v to be present", oid) } @@ -940,8 +812,8 @@ func TestConnInitConnInfo(t *testing.T) { } func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") var n uint64 err := conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n) @@ -956,32 +828,30 @@ func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { } func TestDomainType(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") - - var n uint64 + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") // Domain type uint64 is a PostgreSQL domain of underlying type numeric. - err := conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) + // In the extended protocol preparing "select $1::uint64" appears to create a statement that expects a param OID of + // uint64 but a result OID of the underlying numeric. + + var s string + err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s) require.NoError(t, err) + require.Equal(t, "24", s) - // A string can be used. But a string cannot be the result because the describe result from the PostgreSQL server gives - // the underlying type of numeric. - err = conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n) - if err != nil { - t.Fatal(err) - } - if n != 42 { - t.Fatalf("Expected n to be 42, but was %v", n) - } - + // Register type var uint64OID uint32 err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID) if err != nil { t.Fatalf("did not find uint64 OID, %v", err) } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Numeric{}, Name: "uint64", OID: uint64OID}) + conn.TypeMap().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}}) + + var n uint64 + err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) + require.NoError(t, err) // String is still an acceptable argument after registration err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n) @@ -991,15 +861,48 @@ func TestDomainType(t *testing.T) { if n != 7 { t.Fatalf("Expected n to be 7, but was %v", n) } + }) +} - // But a uint64 is acceptable - err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) - if err != nil { - t.Fatal(err) +func TestLoadTypeSameNameInDifferentSchemas(t *testing.T) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does support composite types (https://github.com/cockroachdb/cockroach/issues/27792)") + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, `create schema pgx_a; +create type pgx_a.point as (a text, b text); +create schema pgx_b; +create type pgx_b.point as (c text); +`) + require.NoError(t, err) + + // Register types + for _, typename := range []string{"pgx_a.point", "pgx_b.point"} { + // Obviously using conn while a tx is in use and registering a type after the connection has been established are + // really bad practices, but for the sake of convenience we do it in the test here. + dt, err := conn.LoadType(ctx, typename) + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) } - if n != 24 { - t.Fatalf("Expected n to be 24, but was %v", n) + + type aPoint struct { + A string + B string } + + type bPoint struct { + C string + } + + var a aPoint + var b bPoint + err = tx.QueryRow(ctx, `select '(foo,bar)'::pgx_a.point, '(baz)'::pgx_b.point`).Scan(&a, &b) + require.NoError(t, err) + require.Equal(t, aPoint{"foo", "bar"}, a) + require.Equal(t, bPoint{"baz"}, b) }) } @@ -1028,6 +931,7 @@ func TestStmtCacheInvalidationConn(t *testing.T) { rows, err := conn.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Close() + require.NoError(t, rows.Err()) // Now, change the schema of the table out from under the statement, making it invalid. _, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") @@ -1045,10 +949,10 @@ func TestStmtCacheInvalidationConn(t *testing.T) { rows.Close() for _, err := range []error{nextErr, rows.Err()} { if err == nil { - t.Fatal("expected InvalidCachedStatementPlanError: no error") + t.Fatal(`expected "cached plan must not change result type": no error`) } if !strings.Contains(err.Error(), "cached plan must not change result type") { - t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) + t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error()) } } @@ -1070,6 +974,10 @@ func TestStmtCacheInvalidationTx(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Server has non-standard prepare in errored transaction behavior (https://github.com/cockroachdb/cockroach/issues/84140)") + } + // create a table and fill it with some data _, err := conn.Exec(ctx, ` DROP TABLE IF EXISTS drop_cols; @@ -1092,6 +1000,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) { rows, err := tx.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Close() + require.NoError(t, rows.Err()) // Now, change the schema of the table out from under the statement, making it invalid. _, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") @@ -1109,18 +1018,17 @@ func TestStmtCacheInvalidationTx(t *testing.T) { rows.Close() for _, err := range []error{nextErr, rows.Err()} { if err == nil { - t.Fatal("expected InvalidCachedStatementPlanError: no error") + t.Fatal(`expected "cached plan must not change result type": no error`) } if !strings.Contains(err.Error(), "cached plan must not change result type") { - t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) + t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error()) } } - rows, err = tx.Query(ctx, getSQL, 1) - require.NoError(t, err) // error does not pop up immediately - rows.Next() + rows, _ = tx.Query(ctx, getSQL, 1) + rows.Close() err = rows.Err() - // Retries within the same transaction are errors (really anything except a rollbakc + // Retries within the same transaction are errors (really anything except a rollback // will be an error in this transaction). require.Error(t, err) rows.Close() @@ -1140,7 +1048,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) { } func TestInsertDurationInterval(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { _, err := conn.Exec(context.Background(), "create temporary table t(duration INTERVAL(0) NOT NULL)") require.NoError(t, err) @@ -1151,3 +1059,34 @@ func TestInsertDurationInterval(t *testing.T) { require.EqualValues(t, 1, n) }) } + +func TestRawValuesUnderlyingMemoryReused(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var buf []byte + + rows, err := conn.Query(ctx, `select 1::int`) + require.NoError(t, err) + + for rows.Next() { + buf = rows.RawValues()[0] + } + + require.NoError(t, rows.Err()) + + original := make([]byte, len(buf)) + copy(original, buf) + + for i := 0; i < 1_000_000; i++ { + rows, err := conn.Query(ctx, `select $1::int`, i) + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + + if bytes.Compare(original, buf) != 0 { + return + } + } + + t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not") + }) +} diff --git a/copy_from.go b/copy_from.go index 49139d05..c8b98c57 100644 --- a/copy_from.go +++ b/copy_from.go @@ -5,20 +5,19 @@ import ( "context" "fmt" "io" - "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn" ) // CopyFromRows returns a CopyFromSource interface over the provided rows slice // making it usable by *Conn.CopyFrom. -func CopyFromRows(rows [][]interface{}) CopyFromSource { +func CopyFromRows(rows [][]any) CopyFromSource { return ©FromRows{rows: rows, idx: -1} } type copyFromRows struct { - rows [][]interface{} + rows [][]any idx int } @@ -27,7 +26,7 @@ func (ctr *copyFromRows) Next() bool { return ctr.idx < len(ctr.rows) } -func (ctr *copyFromRows) Values() ([]interface{}, error) { +func (ctr *copyFromRows) Values() ([]any, error) { return ctr.rows[ctr.idx], nil } @@ -37,12 +36,12 @@ func (ctr *copyFromRows) Err() error { // CopyFromSlice returns a CopyFromSource interface over a dynamic func // making it usable by *Conn.CopyFrom. -func CopyFromSlice(length int, next func(int) ([]interface{}, error)) CopyFromSource { +func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource { return ©FromSlice{next: next, idx: -1, len: length} } type copyFromSlice struct { - next func(int) ([]interface{}, error) + next func(int) ([]any, error) idx int len int err error @@ -53,7 +52,7 @@ func (cts *copyFromSlice) Next() bool { return cts.idx < cts.len } -func (cts *copyFromSlice) Values() ([]interface{}, error) { +func (cts *copyFromSlice) Values() ([]any, error) { values, err := cts.next(cts.idx) if err != nil { cts.err = err @@ -73,7 +72,7 @@ type CopyFromSource interface { Next() bool // Values returns the values for the current row. - Values() ([]interface{}, error) + Values() ([]any, error) // Err returns any error that has been encountered by the CopyFromSource. If // this is not nil *Conn.CopyFrom will abort the copy. @@ -89,6 +88,13 @@ type copyFrom struct { } func (ct *copyFrom) run(ctx context.Context) (int64, error) { + if ct.conn.copyFromTracer != nil { + ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{ + TableName: ct.tableName, + ColumnNames: ct.columnNames, + }) + } + quotedTableName := ct.tableName.Sanitize() cbuf := &bytes.Buffer{} for i, cn := range ct.columnNames { @@ -145,24 +151,19 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { w.Close() }() - startTime := time.Now() - commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) r.Close() <-doneChan - rowsAffected := commandTag.RowsAffected() - endTime := time.Now() - if err == nil { - if ct.conn.shouldLog(LogLevelInfo) { - ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]interface{}{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected}) - } - } else if ct.conn.shouldLog(LogLevelError) { - ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)}) + if ct.conn.copyFromTracer != nil { + ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{ + CommandTag: commandTag, + Err: err, + }) } - return rowsAffected, err + return commandTag.RowsAffected(), err } func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { @@ -178,7 +179,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) for i, val := range values { - buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val) + buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) if err != nil { return false, nil, err } diff --git a/copy_from_test.go b/copy_from_test.go index 20e5b247..49bfcb34 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -8,8 +8,9 @@ import ( "testing" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) @@ -31,7 +32,7 @@ func TestConnCopyFromSmall(t *testing.T) { tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) - inputRows := [][]interface{}{ + inputRows := [][]any{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } @@ -49,7 +50,7 @@ func TestConnCopyFromSmall(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -87,13 +88,13 @@ func TestConnCopyFromSliceSmall(t *testing.T) { tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) - inputRows := [][]interface{}{ + inputRows := [][]any{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, - pgx.CopyFromSlice(len(inputRows), func(i int) ([]interface{}, error) { + pgx.CopyFromSlice(len(inputRows), func(i int) ([]any, error) { return inputRows[i], nil })) if err != nil { @@ -108,7 +109,7 @@ func TestConnCopyFromSliceSmall(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -134,7 +135,7 @@ func TestConnCopyFromLarge(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)") + pgxtest.SkipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)") mustExec(t, conn, `create temporary table foo( a int2, @@ -149,10 +150,10 @@ func TestConnCopyFromLarge(t *testing.T) { tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) - inputRows := [][]interface{}{} + inputRows := [][]any{} for i := 0; i < 10000; i++ { - inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}}) + inputRows = append(inputRows, []any{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}}) } copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) @@ -168,7 +169,7 @@ func TestConnCopyFromLarge(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -211,6 +212,14 @@ func TestConnCopyFromEnum(t *testing.T) { _, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`) require.NoError(t, err) + // Obviously using conn while a tx is in use and registering a type after the connection has been established are + // really bad practices, but for the sake of convenience we do it in the test here. + for _, name := range []string{"fruit", "color"} { + typ, err := conn.LoadType(ctx, name) + require.NoError(t, err) + conn.TypeMap().RegisterType(typ) + } + _, err = tx.Exec(ctx, `create table foo( a text, b color, @@ -221,7 +230,7 @@ func TestConnCopyFromEnum(t *testing.T) { )`) require.NoError(t, err) - inputRows := [][]interface{}{ + inputRows := [][]any{ {"abc", "blue", "grape", "orange", "orange", "def"}, {nil, nil, nil, nil, nil, nil}, } @@ -233,7 +242,7 @@ func TestConnCopyFromEnum(t *testing.T) { rows, err := conn.Query(ctx, "select * from foo") require.NoError(t, err) - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() require.NoError(t, err) @@ -256,7 +265,7 @@ func TestConnCopyFromJSON(t *testing.T) { defer closeConn(t, conn) for _, typeName := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo().DataTypeForName(typeName); !ok { + if _, ok := conn.TypeMap().TypeForName(typeName); !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } } @@ -266,8 +275,8 @@ func TestConnCopyFromJSON(t *testing.T) { b jsonb )`) - inputRows := [][]interface{}{ - {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, + inputRows := [][]any{ + {map[string]any{"foo": "bar"}, map[string]any{"bar": "quz"}}, {nil, nil}, } @@ -284,7 +293,7 @@ func TestConnCopyFromJSON(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -314,12 +323,12 @@ func (cfs *clientFailSource) Next() bool { return cfs.count < 100 } -func (cfs *clientFailSource) Values() ([]interface{}, error) { +func (cfs *clientFailSource) Values() ([]any, error) { if cfs.count == 3 { cfs.err = fmt.Errorf("client error") return nil, cfs.err } - return []interface{}{make([]byte, 100000)}, nil + return []any{make([]byte, 100000)}, nil } func (cfs *clientFailSource) Err() error { @@ -337,7 +346,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) { b varchar not null )`) - inputRows := [][]interface{}{ + inputRows := [][]any{ {int32(1), "abc"}, {int32(2), nil}, // this row should trigger a failure {int32(3), "def"}, @@ -359,7 +368,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -391,11 +400,11 @@ func (fs *failSource) Next() bool { return fs.count < 100 } -func (fs *failSource) Values() ([]interface{}, error) { +func (fs *failSource) Values() ([]any, error) { if fs.count == 3 { - return []interface{}{nil}, nil + return []any{nil}, nil } - return []interface{}{make([]byte, 100000)}, nil + return []any{make([]byte, 100000)}, nil } func (fs *failSource) Err() error { @@ -408,6 +417,8 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + pgxtest.SkipCockroachDB(t, conn, "Server copy error does not fail fast") + mustExec(t, conn, `create temporary table foo( a bytea not null )`) @@ -436,7 +447,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -466,11 +477,11 @@ func (fs *slowFailRaceSource) Next() bool { return fs.count < 1000 } -func (fs *slowFailRaceSource) Values() ([]interface{}, error) { +func (fs *slowFailRaceSource) Values() ([]any, error) { if fs.count == 500 { - return []interface{}{nil, nil}, nil + return []any{nil, nil}, nil } - return []interface{}{1, make([]byte, 1000)}, nil + return []any{1, make([]byte, 1000)}, nil } func (fs *slowFailRaceSource) Err() error { @@ -525,7 +536,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -554,8 +565,8 @@ func (cfs *clientFinalErrSource) Next() bool { return cfs.count < 5 } -func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { - return []interface{}{make([]byte, 100000)}, nil +func (cfs *clientFinalErrSource) Values() ([]any, error) { + return []any{make([]byte, 100000)}, nil } func (cfs *clientFinalErrSource) Err() error { @@ -585,7 +596,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -604,3 +615,32 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnCopyFromAutomaticStringConversion(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int8 + )`) + + inputRows := [][]interface{}{ + {"42"}, + {"7"}, + {8}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + rows, _ := conn.Query(context.Background(), "select * from foo") + nums, err := pgx.CollectRows(rows, pgx.RowTo[int64]) + require.NoError(t, err) + + require.Equal(t, []int64{42, 7, 8}, nums) + + ensureConnValid(t, conn) +} diff --git a/doc.go b/doc.go index 222f9047..497ab660 100644 --- a/doc.go +++ b/doc.go @@ -1,8 +1,9 @@ // Package pgx is a PostgreSQL database driver. /* -pgx provides lower level access to PostgreSQL than the standard database/sql. It remains as similar to the database/sql -interface as possible while providing better speed and access to PostgreSQL specific features. Import -github.com/jackc/pgx/v4/stdlib to use pgx as a database/sql compatible driver. +pgx provides a native PostgreSQL driver and can act as a database/sql driver. The native PostgreSQL interface is similar +to the database/sql interface while providing better speed and access to PostgreSQL specific features. Use +github.com/jackc/pgx/v5/stdlib to use pgx as a database/sql compatible driver. See that package's documentation for +details. Establishing a Connection @@ -12,57 +13,41 @@ The primary way of establishing a connection is with `pgx.Connect`. The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with -`ConnectConfig`. - - config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL")) - if err != nil { - // ... - } - config.Logger = log15adapter.NewLogger(log.New("module", "pgx")) - - conn, err := pgx.ConnectConfig(context.Background(), config) +`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string. Connection Pool -`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use sub-package pgxpool for a -concurrency safe connection pool. +`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use package +github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool. Query Interface -pgx implements Query and Scan in the familiar database/sql style. +pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and +ForEachRow that are a simpler and safer way of processing rows than manually calling rows.Next(), rows.Scan, and +rows.Err(). - var sum int32 +CollectRows can be used collect all returned rows into a slice. - // Send the query to the server. The returned rows MUST be closed - // before conn can be used again. - rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) + rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 5) + numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32]) if err != nil { - return err + return err } + // numbers => [1 2 3 4 5] - // rows.Close is called by rows.Next when all rows are read - // or an error occurs in Next or Scan. So it may optionally be - // omitted if nothing in the rows.Next loop can panic. It is - // safe to close rows multiple times. - defer rows.Close() +ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows +directly. - // Iterate through the result set - for rows.Next() { - var n int32 - err = rows.Scan(&n) - if err != nil { - return err - } - sum += n + var sum, n int32 + rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10) + _, err := pgx.ForEachRow(rows, []any{&n}, func(pgx.QueryFuncRow) error { + sum += n + return nil + }) + if err != nil { + return err } - // Any errors encountered by rows.Next or rows.Scan will be returned here - if rows.Err() != nil { - return rows.Err() - } - - // No errors found - do something with sum - pgx also implements QueryRow in the same style as database/sql. var name string @@ -82,148 +67,10 @@ Use Exec to execute a query that does not return a result set. return errors.New("No row found to delete") } -QueryFunc can be used to execute a callback function for every row. This is often easier to use than Query. +PostgreSQL Data Types - var sum, n int32 - _, err = conn.QueryFunc( - context.Background(), - "select generate_series(1,$1)", - []interface{}{10}, - []interface{}{&n}, - func(pgx.QueryFuncRow) error { - sum += n - return nil - }, - ) - if err != nil { - return err - } - -Base Type Mapping - -pgx maps between all common base types directly between Go and PostgreSQL. In particular: - - Go PostgreSQL - ----------------------- - string varchar - text - - // Integers are automatically be converted to any other integer type if - // it can be done without overflow or underflow. - int8 - int16 smallint - int32 int - int64 bigint - int - uint8 - uint16 - uint32 - uint64 - uint - - // Floats are strict and do not automatically convert like integers. - float32 float4 - float64 float8 - - time.Time date - timestamp - timestamptz - - []byte bytea - - -Null Mapping - -pgx can map nulls in two ways. The first is package pgtype provides types that have a data field and a status field. -They work in a similar fashion to database/sql. The second is to use a pointer to a pointer. - - var foo pgtype.Varchar - var bar *string - err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar) - if err != nil { - return err - } - -Array Mapping - -pgx maps between int16, int32, int64, float32, float64, and string Go slices and the equivalent PostgreSQL array type. -Go slices of native types do not support nulls, so if a PostgreSQL array that contains a null is read into a native Go -slice an error will occur. The pgtype package includes many more array types for PostgreSQL types that do not directly -map to native Go types. - -JSON and JSONB Mapping - -pgx includes built-in support to marshal and unmarshal between Go types and the PostgreSQL JSON and JSONB. - -Inet and CIDR Mapping - -pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In addition, as a convenience pgx will encode -from a net.IP; it will assume a /32 netmask for IPv4 and a /128 for IPv6. - -Custom Type Support - -pgx includes support for the common data types like integers, floats, strings, dates, and times that have direct -mappings between Go and SQL. In addition, pgx uses the github.com/jackc/pgtype library to support more types. See -documention for that library for instructions on how to implement custom types. - -See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. - -pgx also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer -interfaces. - -If pgx does cannot natively encode a type and that type is a renamed type (e.g. type MyTime time.Time) pgx will attempt -to encode the underlying type. While this is usually desired behavior it can produce surprising behavior if one the -underlying type and the renamed type each implement database/sql interfaces and the other implements pgx interfaces. It -is recommended that this situation be avoided by implementing pgx interfaces on the renamed type. - -Composite types and row values - -Row values and composite types are represented as pgtype.Record (https://pkg.go.dev/github.com/jackc/pgtype?tab=doc#Record). -It is possible to get values of your custom type by implementing DecodeBinary interface. Decoding into -pgtype.Record first can simplify process by avoiding dealing with raw protocol directly. - -For example: - - type MyType struct { - a int // NULL will cause decoding error - b *string // there can be NULL in this position in SQL - } - - func (t *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - r := pgtype.Record{ - Fields: []pgtype.Value{&pgtype.Int4{}, &pgtype.Text{}}, - } - - if err := r.DecodeBinary(ci, src); err != nil { - return err - } - - if r.Status != pgtype.Present { - return errors.New("BUG: decoding should not be called on NULL value") - } - - a := r.Fields[0].(*pgtype.Int4) - b := r.Fields[1].(*pgtype.Text) - - // type compatibility is checked by AssignTo - // only lossless assignments will succeed - if err := a.AssignTo(&t.a); err != nil { - return err - } - - // AssignTo also deals with null value handling - if err := b.AssignTo(&t.b); err != nil { - return err - } - return nil - } - - result := MyType{} - err := conn.QueryRow(context.Background(), "select row(1, 'foo'::text)", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&r) - -Raw Bytes Mapping - -[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified to PostgreSQL. +The package pgtype provides extensive and customizable support for converting Go values to and from PostgreSQL values +including array and composite types. See that package's documentation for details. Transactions @@ -250,12 +97,13 @@ Transactions are started by calling Begin. The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions. These are internally implemented with savepoints. -Use BeginTx to control the transaction mode. +Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of +a pseudo nested transaction. -BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the +BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the transaction depending on the return value of the function. These can be simpler and less error prone to use. - err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") return err }) @@ -273,10 +121,10 @@ for information on how to customize or disable the statement cache. Copy Protocol Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a -CopyFromSource interface. If the data is already in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource -interface. Or implement CopyFromSource to avoid buffering the entire data set in memory. +CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource interface. +Or implement CopyFromSource to avoid buffering the entire data set in memory. - rows := [][]interface{}{ + rows := [][]any{ {"John", "Smith", int32(36)}, {"Jane", "Doe", int32(29)}, } @@ -299,8 +147,8 @@ When you already have a typed array using CopyFromSlice can be more convenient. context.Background(), pgx.Identifier{"people"}, []string{"first_name", "last_name", "age"}, - pgx.CopyFromSlice(len(rows), func(i int) ([]interface{}, error) { - return []interface{}{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil + pgx.CopyFromSlice(len(rows), func(i int) ([]any, error) { + return []any{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil }), ) @@ -321,20 +169,22 @@ notification is received or the context is canceled. } -Logging +Tracing and Logging -pgx defines a simple logger interface. Connections optionally accept a logger that satisfies this interface. Set -LogLevel to control logging verbosity. Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus, -go.uber.org/zap, github.com/rs/zerolog, and the testing log are provided in the log directory. +pgx supports tracing by setting ConnConfig.Tracer. + +In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer. + +For debug tracing of the actual PostgreSQL wire protocol messages see github.com/jackc/pgx/v5/pgproto3. Lower Level PostgreSQL Functionality -pgx is implemented on top of github.com/jackc/pgconn a lower level PostgreSQL driver. The Conn.PgConn() method can be -used to access this lower layer. +github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn in +implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer. PgBouncer -pgx is compatible with PgBouncer in two modes. One is when the connection has a statement cache in "describe" mode. The -other is when the connection is using the simple protocol. This can be set with the PreferSimpleProtocol config option. +By default pgx automatically uses prepared statements. Prepared statements are incompaptible with PgBouncer. This can be +disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode. */ package pgx diff --git a/example_custom_type_test.go b/example_custom_type_test.go deleted file mode 100644 index 34331f5b..00000000 --- a/example_custom_type_test.go +++ /dev/null @@ -1,115 +0,0 @@ -package pgx_test - -import ( - "context" - "fmt" - "os" - "regexp" - "strconv" - - "github.com/jackc/pgtype" - "github.com/jackc/pgx/v4" -) - -var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) - -// Point represents a point that may be null. -type Point struct { - X, Y float64 // Coordinates of point - Status pgtype.Status -} - -func (dst *Point) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Point", src) -} - -func (dst *Point) Get() interface{} { - switch dst.Status { - case pgtype.Present: - return dst - case pgtype.Null: - return nil - default: - return dst.Status - } -} - -func (src *Point) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = Point{Status: pgtype.Null} - return nil - } - - s := string(src) - match := pointRegexp.FindStringSubmatch(s) - if match == nil { - return fmt.Errorf("Received invalid point: %v", s) - } - - x, err := strconv.ParseFloat(match[1], 64) - if err != nil { - return fmt.Errorf("Received invalid point: %v", s) - } - y, err := strconv.ParseFloat(match[2], 64) - if err != nil { - return fmt.Errorf("Received invalid point: %v", s) - } - - *dst = Point{X: x, Y: y, Status: pgtype.Present} - - return nil -} - -func (src *Point) String() string { - if src.Status == pgtype.Null { - return "null point" - } - - return fmt.Sprintf("%.1f, %.1f", src.X, src.Y) -} - -func Example_CustomType() { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - fmt.Printf("Unable to establish connection: %v", err) - return - } - defer conn.Close(context.Background()) - - if conn.PgConn().ParameterStatus("crdb_version") != "" { - // Skip test / example when running on CockroachDB which doesn't support the point type. Since an example can't be - // skipped fake success instead. - fmt.Println("null point") - fmt.Println("1.5, 2.5") - return - } - - // Override registered handler for point - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: &Point{}, - Name: "point", - OID: 600, - }) - - p := &Point{} - err = conn.QueryRow(context.Background(), "select null::point").Scan(p) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(p) - - err = conn.QueryRow(context.Background(), "select point(1.5,2.5)").Scan(p) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(p) - // Output: - // null point - // 1.5, 2.5 -} diff --git a/examples/chat/main.go b/examples/chat/main.go index 6be4ee1c..5adbb3b6 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -6,14 +6,14 @@ import ( "fmt" "os" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5/pgxpool" ) var pool *pgxpool.Pool func main() { var err error - pool, err = pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL")) + pool, err = pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) os.Exit(1) diff --git a/examples/todo/main.go b/examples/todo/main.go index 9aa8c1cb..6c644ede 100644 --- a/examples/todo/main.go +++ b/examples/todo/main.go @@ -6,7 +6,7 @@ import ( "os" "strconv" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" ) var conn *pgx.Conn diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index c5e87eb3..12195922 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -3,13 +3,12 @@ package main import ( "context" "io/ioutil" + "log" "net/http" "os" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/log/log15adapter" - "github.com/jackc/pgx/v4/pgxpool" - log "gopkg.in/inconshreveable/log15.v2" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) var db *pgxpool.Pool @@ -71,28 +70,21 @@ func urlHandler(w http.ResponseWriter, req *http.Request) { } func main() { - logger := log15adapter.NewLogger(log.New("module", "pgx")) - poolConfig, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) if err != nil { - log.Crit("Unable to parse DATABASE_URL", "error", err) - os.Exit(1) + log.Fatalln("Unable to parse DATABASE_URL:", err) } - poolConfig.ConnConfig.Logger = logger - - db, err = pgxpool.ConnectConfig(context.Background(), poolConfig) + db, err = pgxpool.NewWithConfig(context.Background(), poolConfig) if err != nil { - log.Crit("Unable to create connection pool", "error", err) - os.Exit(1) + log.Fatalln("Unable to create connection pool:", err) } http.HandleFunc("/", urlHandler) - log.Info("Starting URL shortener on localhost:8080") + log.Println("Starting URL shortener on localhost:8080") err = http.ListenAndServe("localhost:8080", nil) if err != nil { - log.Crit("Unable to start web server", "error", err) - os.Exit(1) + log.Fatalln("Unable to start web server:", err) } } diff --git a/extended_query_builder.go b/extended_query_builder.go index d06f63fd..b0c0e02b 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -1,69 +1,118 @@ package pgx import ( - "database/sql/driver" "fmt" - "reflect" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v5/internal/anynil" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" ) -type extendedQueryBuilder struct { - paramValues [][]byte +// ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result +// formats for an extended query. +type ExtendedQueryBuilder struct { + ParamValues [][]byte paramValueBytes []byte - paramFormats []int16 - resultFormats []int16 + ParamFormats []int16 + ResultFormats []int16 } -func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error { - f := chooseParameterFormatCode(ci, oid, arg) - eqb.paramFormats = append(eqb.paramFormats, f) +// Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If +// sd is nil then QueryExecModeExec behavior will be used. +func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error { + eqb.reset() - v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg) - if err != nil { - return err + anynil.NormalizeSlice(args) + + if sd == nil { + return eqb.appendParamsForQueryExecModeExec(m, args) + } + + if len(sd.ParamOIDs) != len(args) { + return fmt.Errorf("mismatched param and argument count") + } + + for i := range args { + err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i]) + if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %v", i, err) + return err + } + } + + for i := range sd.Fields { + eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID)) } - eqb.paramValues = append(eqb.paramValues, v) return nil } -func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) { - eqb.resultFormats = append(eqb.resultFormats, f) +// appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it +// must be an untyped nil. +func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error { + if format == -1 { + preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg) + preferredErr := eqb.appendParam(m, oid, preferredFormat, arg) + if preferredErr == nil { + return nil + } + + var otherFormat int16 + if preferredFormat == TextFormatCode { + otherFormat = BinaryFormatCode + } else { + otherFormat = TextFormatCode + } + + otherErr := eqb.appendParam(m, oid, otherFormat, arg) + if otherErr == nil { + return nil + } + + return preferredErr // return the error from the preferred format + } + + v, err := eqb.encodeExtendedParamValue(m, oid, format, arg) + if err != nil { + return err + } + + eqb.ParamFormats = append(eqb.ParamFormats, format) + eqb.ParamValues = append(eqb.ParamValues, v) + + return nil } -// Reset readies eqb to build another query. -func (eqb *extendedQueryBuilder) Reset() { - eqb.paramValues = eqb.paramValues[0:0] - eqb.paramValueBytes = eqb.paramValueBytes[0:0] - eqb.paramFormats = eqb.paramFormats[0:0] - eqb.resultFormats = eqb.resultFormats[0:0] +// appendResultFormat appends a result format to the query. +func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) { + eqb.ResultFormats = append(eqb.ResultFormats, format) +} - if cap(eqb.paramValues) > 64 { - eqb.paramValues = make([][]byte, 0, 64) +// reset readies eqb to build another query. +func (eqb *ExtendedQueryBuilder) reset() { + eqb.ParamValues = eqb.ParamValues[0:0] + eqb.paramValueBytes = eqb.paramValueBytes[0:0] + eqb.ParamFormats = eqb.ParamFormats[0:0] + eqb.ResultFormats = eqb.ResultFormats[0:0] + + if cap(eqb.ParamValues) > 64 { + eqb.ParamValues = make([][]byte, 0, 64) } if cap(eqb.paramValueBytes) > 256 { eqb.paramValueBytes = make([]byte, 0, 256) } - if cap(eqb.paramFormats) > 64 { - eqb.paramFormats = make([]int16, 0, 64) + if cap(eqb.ParamFormats) > 64 { + eqb.ParamFormats = make([]int16, 0, 64) } - if cap(eqb.resultFormats) > 64 { - eqb.resultFormats = make([]int16, 0, 64) + if cap(eqb.ResultFormats) > 64 { + eqb.ResultFormats = make([]int16, 0, 64) } } -func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { - if arg == nil { - return nil, nil - } - - refVal := reflect.ValueOf(arg) - argIsPtr := refVal.Kind() == reflect.Ptr - - if argIsPtr && refVal.IsNil() { +func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) { + if anynil.Is(arg) { return nil, nil } @@ -71,91 +120,85 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o eqb.paramValueBytes = make([]byte, 0, 128) } - var err error - var buf []byte pos := len(eqb.paramValueBytes) - if arg, ok := arg.(string); ok { - return []byte(arg), nil + buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil +} + +// chooseParameterFormatCode determines the correct format code for an +// argument to a prepared statement. It defaults to TextFormatCode if no +// determination can be made. +func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 { + switch arg.(type) { + case string, *string: + return TextFormatCode } - if formatCode == TextFormatCode { - if arg, ok := arg.(pgtype.TextEncoder); ok { - buf, err = arg.EncodeText(ci, eqb.paramValueBytes) + return m.FormatCodeForOID(oid) +} + +// appendParamsForQueryExecModeExec appends the args to eqb. +// +// Parameters must be encoded in the text format because of differences in type conversion between timestamps and +// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the +// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both +// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL +// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date. +// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion +// before converting it to date. This means that dates can be shifted by one day. In text format without that double +// type conversion it takes the date directly and ignores time zone (i.e. it works). +// +// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is +// no way to safely use binary or to specify the parameter OIDs. +func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error { + for _, arg := range args { + if arg == nil { + err := eqb.appendParam(m, 0, TextFormatCode, arg) if err != nil { - return nil, err + return err } - if buf == nil { - return nil, nil - } - eqb.paramValueBytes = buf - return eqb.paramValueBytes[pos:], nil - } - } else if formatCode == BinaryFormatCode { - if arg, ok := arg.(pgtype.BinaryEncoder); ok { - buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - eqb.paramValueBytes = buf - return eqb.paramValueBytes[pos:], nil - } - } - - if argIsPtr { - // We have already checked that arg is not pointing to nil, - // so it is safe to dereference here. - arg = refVal.Elem().Interface() - return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg) - } - - if dt, ok := ci.DataTypeForOID(oid); ok { - value := dt.Value - err := value.Set(arg) - if err != nil { - { - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) + } else { + dt, ok := m.TypeForValue(arg) + if !ok { + var tv pgtype.TextValuer + if tv, ok = arg.(pgtype.TextValuer); ok { + t, err := tv.TextValue() if err != nil { - return nil, err + return err + } + + dt, ok = m.TypeForOID(pgtype.TextOID) + if ok { + arg = t } - return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) } } - - return nil, err - } - - return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) - } - - // There is no data type registered for the destination OID, but maybe there is data type registered for the arg - // type. If so use it's text encoder (if available). - if dt, ok := ci.DataTypeForValue(arg); ok { - value := dt.Value - if textEncoder, ok := value.(pgtype.TextEncoder); ok { - err := value.Set(arg) + if !ok { + var str fmt.Stringer + if str, ok = arg.(fmt.Stringer); ok { + dt, ok = m.TypeForOID(pgtype.TextOID) + if ok { + arg = str.String() + } + } + } + if !ok { + return &unknownArgumentTypeQueryExecModeExecError{arg: arg} + } + err := eqb.appendParam(m, dt.OID, TextFormatCode, arg) if err != nil { - return nil, err + return err } - - buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - eqb.paramValueBytes = buf - return eqb.paramValueBytes[pos:], nil } } - if strippedArg, ok := stripNamedType(&refVal); ok { - return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg) - } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + return nil } diff --git a/go.mod b/go.mod index 37574c2e..5b7109ac 100644 --- a/go.mod +++ b/go.mod @@ -1,21 +1,20 @@ -module github.com/jackc/pgx/v4 +module github.com/jackc/pgx/v5 -go 1.13 +go 1.18 require ( - github.com/Masterminds/semver/v3 v3.1.1 - github.com/cockroachdb/apd v1.1.0 - github.com/go-kit/log v0.1.0 - github.com/gofrs/uuid v4.0.0+incompatible - github.com/jackc/pgconn v1.13.0 - github.com/jackc/pgio v1.0.0 - github.com/jackc/pgproto3/v2 v2.3.1 - github.com/jackc/pgtype v1.12.0 - github.com/jackc/puddle v1.3.0 - github.com/rs/zerolog v1.15.0 - github.com/shopspring/decimal v1.2.0 - github.com/sirupsen/logrus v1.4.2 + github.com/jackc/pgpassfile v1.0.0 + github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b + github.com/jackc/puddle/v2 v2.0.0 github.com/stretchr/testify v1.8.0 - go.uber.org/zap v1.13.0 - gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec + golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 + golang.org/x/text v0.3.7 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.3.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 469a3df3..0f1a952c 100644 --- a/go.sum +++ b/go.sum @@ -1,208 +1,42 @@ -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= -github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= -github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= -github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-kit/log v0.1.0 h1:DGJh0Sm43HbOeYDNnVZFl8BvcYVvjD5bqYJvp0REbwQ= -github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= -github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4= -github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= -github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= -github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= -github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= -github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= -github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= -github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= -github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= -github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= -github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= -github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= -github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= -github.com/jackc/pgconn v1.13.0 h1:3L1XMNV2Zvca/8BYhzcRFS70Lr0WlDg16Di6SFGAbys= -github.com/jackc/pgconn v1.13.0/go.mod h1:AnowpAqO4CMIIJNZl2VJp+KrkAZciAkhEl0W0JIobpI= -github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= -github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= -github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= -github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= -github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= -github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.3.1 h1:nwj7qwf0S+Q7ISFfBndqeLwSwxs+4DPsbRFjECT1Y4Y= -github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= -github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= -github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= -github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= -github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= -github.com/jackc/pgtype v1.12.0 h1:Dlq8Qvcch7kiehm8wPGIW0W3KsCCHJnRacKW0UM8n5w= -github.com/jackc/pgtype v1.12.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= -github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= -github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= -github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= -github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.2.1 h1:gI8os0wpRXFd4FiAY2dWiqRK037tjj3t7rKFeO4X5iw= -github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0= -github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= -github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/jackc/puddle/v2 v2.0.0 h1:Kwk/AlLigcnZsDssc3Zun1dk1tAtQNPaBBxBHWn0Mjc= +github.com/jackc/puddle/v2 v2.0.0/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= -github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= -github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= -github.com/rs/zerolog v1.15.0 h1:uPRuwkWF4J6fGsJ2R0Gn2jB1EQiav9k3S6CSdygQJXY= -github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= -github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= -github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= -github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= -github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= -go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= -go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= -go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= -go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= -go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= -go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.13.0 h1:nR6NoDBgAf67s68NhaXbsojM+2gxp3S1hWkHDl27pVU= -go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= -golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 h1:Y/gsMcFOcR+6S6f3YeMKl5g+dZMEWqcz5Czj/GWYbkM= +golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200103221440-774c71fcf114 h1:DnSr2mCsxyCE6ZgIkmcWUQY2R5cH/6wL7eIxEmQOMSE= -golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec h1:RlWgLqCMMIYYEVcAR5MDsuHlVkaIPDAF+5Dehzg8L5A= -gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/go_stdlib.go b/go_stdlib.go deleted file mode 100644 index 9372f9ef..00000000 --- a/go_stdlib.go +++ /dev/null @@ -1,61 +0,0 @@ -package pgx - -import ( - "database/sql/driver" - "reflect" -) - -// This file contains code copied from the Go standard library due to the -// required function not being public. - -// Copyright (c) 2009 The Go Authors. All rights reserved. - -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: - -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. - -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -// From database/sql/convert.go - -var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() - -// callValuerValue returns vr.Value(), with one exception: -// If vr.Value is an auto-generated method on a pointer type and the -// pointer is nil, it would panic at runtime in the panicwrap -// method. Treat it like nil instead. -// Issue 8415. -// -// This is so people can implement driver.Value on value types and -// still use nil pointers to those types to mean nil/NULL, just like -// string/*string. -// -// This function is mirrored in the database/sql/driver package. -func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { - if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && - rv.IsNil() && - rv.Type().Elem().Implements(valuerReflectType) { - return nil, nil - } - return vr.Value() -} diff --git a/helper_test.go b/helper_test.go index b9cd21c1..26e54621 100644 --- a/helper_test.go +++ b/helper_test.go @@ -7,48 +7,21 @@ import ( "github.com/stretchr/testify/assert" - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) -func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) { - t.Run("SimpleProto", - func(t *testing.T) { - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) +var defaultConnTestRunner pgxtest.ConnTestRunner - config.PreferSimpleProtocol = true - conn, err := pgx.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer func() { - err := conn.Close(context.Background()) - require.NoError(t, err) - }() - - f(t, conn) - - ensureConnValid(t, conn) - }, - ) - - t.Run("DefaultProto", - func(t *testing.T) { - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - conn, err := pgx.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer func() { - err := conn.Close(context.Background()) - require.NoError(t, err) - }() - - f(t, conn) - - ensureConnValid(t, conn) - }, - ) +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config + } } func mustConnectString(t testing.TB, connString string) *pgx.Conn { @@ -80,7 +53,7 @@ func closeConn(t testing.TB, conn *pgx.Conn) { } } -func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) { +func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...any) (commandTag pgconn.CommandTag) { var err error if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil { t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err) @@ -89,7 +62,7 @@ func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{} } // Do a simple query to ensure the connection is still usable -func ensureConnValid(t *testing.T, conn *pgx.Conn) { +func ensureConnValid(t testing.TB, conn *pgx.Conn) { var sum, rowCount int32 rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) @@ -125,13 +98,11 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName return } - assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) - assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) + assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName) assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) - // Can't test function equality, so just test that they are set or not. - assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) - assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) - + assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName) + assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName) + assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) @@ -165,9 +136,3 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName } } } - -func skipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) { - if conn.PgConn().ParameterStatus("crdb_version") != "" { - t.Skip(msg) - } -} diff --git a/internal/anynil/anynil.go b/internal/anynil/anynil.go new file mode 100644 index 00000000..9a48c1a8 --- /dev/null +++ b/internal/anynil/anynil.go @@ -0,0 +1,36 @@ +package anynil + +import "reflect" + +// Is returns true if value is any type of nil. e.g. nil or []byte(nil). +func Is(value any) bool { + if value == nil { + return true + } + + refVal := reflect.ValueOf(value) + switch refVal.Kind() { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + return refVal.IsNil() + default: + return false + } +} + +// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified. +func Normalize(v any) any { + if Is(v) { + return nil + } + return v +} + +// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is +// mutated in place. +func NormalizeSlice(s []any) { + for i := range s { + if Is(s[i]) { + s[i] = nil + } + } +} diff --git a/internal/iobufpool/iobufpool.go b/internal/iobufpool/iobufpool.go new file mode 100644 index 00000000..9e55c435 --- /dev/null +++ b/internal/iobufpool/iobufpool.go @@ -0,0 +1,57 @@ +// Package iobufpool implements a global segregated-fit pool of buffers for IO. +package iobufpool + +import "sync" + +const minPoolExpOf2 = 8 + +var pools [18]*sync.Pool + +func init() { + for i := range pools { + bufLen := 1 << (minPoolExpOf2 + i) + pools[i] = &sync.Pool{New: func() any { return make([]byte, bufLen) }} + } +} + +// Get gets a []byte of len size with cap <= size*2. +func Get(size int) []byte { + i := getPoolIdx(size) + if i >= len(pools) { + return make([]byte, size) + } + return pools[i].Get().([]byte)[:size] +} + +func getPoolIdx(size int) int { + size-- + size >>= minPoolExpOf2 + i := 0 + for size > 0 { + size >>= 1 + i++ + } + + return i +} + +// Put returns buf to the pool. +func Put(buf []byte) { + i := putPoolIdx(cap(buf)) + if i < 0 { + return + } + + pools[i].Put(buf) +} + +func putPoolIdx(size int) int { + minPoolSize := 1 << minPoolExpOf2 + for i := range pools { + if size == minPoolSize<= len(bq.queue) { + bq.growQueue() + } + bq.queue[bq.w] = buf + bq.w++ +} + +func (bq *bufferQueue) pushFront(buf []byte) { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.w >= len(bq.queue) { + bq.growQueue() + } + copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) + bq.queue[bq.r] = buf + bq.w++ +} + +func (bq *bufferQueue) popFront() []byte { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.r == bq.w { + return nil + } + + buf := bq.queue[bq.r] + bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. + bq.r++ + + if bq.r == bq.w { + bq.r = 0 + bq.w = 0 + if len(bq.queue) > minBufferQueueLen { + bq.queue = make([][]byte, minBufferQueueLen) + } + } + + return buf +} + +func (bq *bufferQueue) growQueue() { + desiredLen := (len(bq.queue) + 1) * 3 / 2 + if desiredLen < minBufferQueueLen { + desiredLen = minBufferQueueLen + } + + newQueue := make([][]byte, desiredLen) + copy(newQueue, bq.queue) + bq.queue = newQueue +} diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go new file mode 100644 index 00000000..c091ea3f --- /dev/null +++ b/internal/nbconn/nbconn.go @@ -0,0 +1,476 @@ +// Package nbconn implements a non-blocking net.Conn wrapper. +// +// It is designed to solve three problems. +// +// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all +// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. +// +// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. +// +// The third is to efficiently check if a connection has been closed via a non-blocking read. +package nbconn + +import ( + "crypto/tls" + "errors" + "net" + "os" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +var errClosed = errors.New("closed") +var ErrWouldBlock = new(wouldBlockError) + +const fakeNonblockingWaitDuration = 100 * time.Millisecond + +// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read +// mode. +var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC) + +// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to +// ignore all future calls. +var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC) + +// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error. +type wouldBlockError struct{} + +func (*wouldBlockError) Error() string { + return "would block" +} + +func (*wouldBlockError) Timeout() bool { return true } +func (*wouldBlockError) Temporary() bool { return true } + +// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to +// the underlying connection. +type Conn interface { + net.Conn + + // Flush flushes any buffered writes. + Flush() error + + // BufferReadUntilBlock reads and buffers any sucessfully read bytes until the read would block. + BufferReadUntilBlock() error +} + +// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. +type NetConn struct { + conn net.Conn + rawConn syscall.RawConn + + readQueue bufferQueue + writeQueue bufferQueue + + readFlushLock sync.Mutex + // non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the + // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. + nonblockWriteBuf []byte + nonblockWriteErr error + nonblockWriteN int + + readDeadlineLock sync.Mutex + readDeadline time.Time + readNonblocking bool + + writeDeadlineLock sync.Mutex + writeDeadline time.Time + + // Only access with atomics + closed int64 // 0 = not closed, 1 = closed +} + +func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { + nc := &NetConn{ + conn: conn, + } + + if !fakeNonBlockingIO { + if sc, ok := conn.(syscall.Conn); ok { + if rawConn, err := sc.SyscallConn(); err == nil { + nc.rawConn = rawConn + } + } + } + + return nc +} + +// Read implements io.Reader. +func (c *NetConn) Read(b []byte) (n int, err error) { + if c.isClosed() { + return 0, errClosed + } + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + + err = c.flush() + if err != nil { + return 0, err + } + + for n < len(b) { + buf := c.readQueue.popFront() + if buf == nil { + break + } + copiedN := copy(b[n:], buf) + if copiedN < len(buf) { + buf = buf[copiedN:] + c.readQueue.pushFront(buf) + } else { + iobufpool.Put(buf) + } + n += copiedN + } + + // If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to + // Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block. + if n > 0 { + return n, nil + } + + var readNonblocking bool + c.readDeadlineLock.Lock() + readNonblocking = c.readNonblocking + c.readDeadlineLock.Unlock() + + var readN int + if readNonblocking { + readN, err = c.nonblockingRead(b[n:]) + } else { + readN, err = c.conn.Read(b[n:]) + } + n += readN + return n, err +} + +// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is +// closed. Call Flush to actually write to the underlying connection. +func (c *NetConn) Write(b []byte) (n int, err error) { + if c.isClosed() { + return 0, errClosed + } + + buf := iobufpool.Get(len(b)) + copy(buf, b) + c.writeQueue.pushBack(buf) + return len(b), nil +} + +func (c *NetConn) Close() (err error) { + swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) + if !swapped { + return errClosed + } + + defer func() { + closeErr := c.conn.Close() + if err == nil { + err = closeErr + } + }() + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + err = c.flush() + if err != nil { + return err + } + + return nil +} + +func (c *NetConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *NetConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). +func (c *NetConn) SetDeadline(t time.Time) error { + err := c.SetReadDeadline(t) + if err != nil { + return err + } + return c.SetWriteDeadline(t) +} + +// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. +func (c *NetConn) SetReadDeadline(t time.Time) error { + if c.isClosed() { + return errClosed + } + + c.readDeadlineLock.Lock() + defer c.readDeadlineLock.Unlock() + if c.readDeadline == disableSetDeadlineDeadline { + return nil + } + if t == disableSetDeadlineDeadline { + c.readDeadline = t + return nil + } + + if t == NonBlockingDeadline { + c.readNonblocking = true + t = time.Time{} + } else { + c.readNonblocking = false + } + + c.readDeadline = t + + return c.conn.SetReadDeadline(t) +} + +func (c *NetConn) SetWriteDeadline(t time.Time) error { + if c.isClosed() { + return errClosed + } + + c.writeDeadlineLock.Lock() + defer c.writeDeadlineLock.Unlock() + if c.writeDeadline == disableSetDeadlineDeadline { + return nil + } + if t == disableSetDeadlineDeadline { + c.writeDeadline = t + return nil + } + + c.writeDeadline = t + + return c.conn.SetWriteDeadline(t) +} + +func (c *NetConn) Flush() error { + if c.isClosed() { + return errClosed + } + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + return c.flush() +} + +// flush does the actual work of flushing the writeQueue. readFlushLock must already be held. +func (c *NetConn) flush() error { + var stopChan chan struct{} + var errChan chan error + + defer func() { + if stopChan != nil { + select { + case stopChan <- struct{}{}: + case <-errChan: + } + } + }() + + for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() { + remainingBuf := buf + for len(remainingBuf) > 0 { + n, err := c.nonblockingWrite(remainingBuf) + remainingBuf = remainingBuf[n:] + if err != nil { + if !errors.Is(err, ErrWouldBlock) { + buf = buf[:len(remainingBuf)] + copy(buf, remainingBuf) + c.writeQueue.pushFront(buf) + return err + } + + // Writing was blocked. Reading might unblock it. + if stopChan == nil { + stopChan, errChan = c.bufferNonblockingRead() + } + + select { + case err := <-errChan: + stopChan = nil + return err + default: + } + + } + } + iobufpool.Put(buf) + } + + return nil +} + +func (c *NetConn) BufferReadUntilBlock() error { + for { + buf := iobufpool.Get(8 * 1024) + n, err := c.nonblockingRead(buf) + if n > 0 { + buf = buf[:n] + c.readQueue.pushBack(buf) + } + + if err != nil { + if errors.Is(err, ErrWouldBlock) { + return nil + } else { + return err + } + } + } +} + +func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { + stopChan = make(chan struct{}) + errChan = make(chan error, 1) + + go func() { + for { + err := c.BufferReadUntilBlock() + if err != nil { + errChan <- err + return + } + + select { + case <-stopChan: + return + default: + } + } + }() + + return stopChan, errChan +} + +func (c *NetConn) isClosed() bool { + closed := atomic.LoadInt64(&c.closed) + return closed == 1 +} + +func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { + if c.rawConn == nil { + return c.fakeNonblockingWrite(b) + } else { + return c.realNonblockingWrite(b) + } +} + +func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { + c.writeDeadlineLock.Lock() + defer c.writeDeadlineLock.Unlock() + + deadline := time.Now().Add(fakeNonblockingWaitDuration) + if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { + err = c.conn.SetWriteDeadline(deadline) + if err != nil { + return 0, err + } + defer func() { + // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. + c.conn.SetWriteDeadline(c.writeDeadline) + + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + err = ErrWouldBlock + } + } + }() + } + + return c.conn.Write(b) +} + +func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { + if c.rawConn == nil { + return c.fakeNonblockingRead(b) + } else { + return c.realNonblockingRead(b) + } +} + +func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { + c.readDeadlineLock.Lock() + defer c.readDeadlineLock.Unlock() + + deadline := time.Now().Add(fakeNonblockingWaitDuration) + if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { + err = c.conn.SetReadDeadline(deadline) + if err != nil { + return 0, err + } + defer func() { + // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. + c.conn.SetReadDeadline(c.readDeadline) + + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + err = ErrWouldBlock + } + } + }() + } + + return c.conn.Read(b) +} + +// syscall.Conn is interface + +// TLSClient establishes a TLS connection as a client over conn using config. +// +// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby +// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the +// *TLSConn is returned. +func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) { + tc := tls.Client(conn, config) + err := tc.Handshake() + if err != nil { + return nil, err + } + + // Ensure last written part of Handshake is actually sent. + err = conn.Flush() + if err != nil { + return nil, err + } + + return &TLSConn{ + tlsConn: tc, + nbConn: conn, + }, nil +} + +// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a +// tls.Conn. +type TLSConn struct { + tlsConn *tls.Conn + nbConn *NetConn +} + +func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } +func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } +func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() } +func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } +func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } +func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } + +func (tc *TLSConn) Close() error { + // tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then + // sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our + // own 5 second deadline then make all set deadlines no-op. + tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) + tc.tlsConn.SetDeadline(disableSetDeadlineDeadline) + + return tc.tlsConn.Close() +} + +func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) } +func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) } +func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) } diff --git a/internal/nbconn/nbconn_fake_non_block.go b/internal/nbconn/nbconn_fake_non_block.go new file mode 100644 index 00000000..cf05df1c --- /dev/null +++ b/internal/nbconn/nbconn_fake_non_block.go @@ -0,0 +1,13 @@ +//go:build !(aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris) + +package nbconn + +// Not using unix build tag for support on Go 1.18. + +func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { + return c.fakeNonblockingWrite(b) +} + +func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { + return c.fakeNonblockingRead(b) +} diff --git a/internal/nbconn/nbconn_real_non_block.go b/internal/nbconn/nbconn_real_non_block.go new file mode 100644 index 00000000..ee48d129 --- /dev/null +++ b/internal/nbconn/nbconn_real_non_block.go @@ -0,0 +1,70 @@ +//go:build aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris + +package nbconn + +// Not using unix build tag for support on Go 1.18. + +import ( + "errors" + "io" + "syscall" +) + +// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. +func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { + c.nonblockWriteBuf = b + c.nonblockWriteN = 0 + c.nonblockWriteErr = nil + err = c.rawConn.Write(func(fd uintptr) (done bool) { + c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) + return true + }) + n = c.nonblockWriteN + if err == nil && c.nonblockWriteErr != nil { + if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = c.nonblockWriteErr + } + } + if err != nil { + // n may be -1 when an error occurs. + if n < 0 { + n = 0 + } + + return n, err + } + + return n, nil +} + +func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { + var funcErr error + err = c.rawConn.Read(func(fd uintptr) (done bool) { + n, funcErr = syscall.Read(int(fd), b) + return true + }) + if err == nil && funcErr != nil { + if errors.Is(funcErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = funcErr + } + } + if err != nil { + // n may be -1 when an error occurs. + if n < 0 { + n = 0 + } + + return n, err + } + + // syscall read did not return an error and 0 bytes were read means EOF. + if n == 0 { + return 0, io.EOF + } + + return n, nil +} diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go new file mode 100644 index 00000000..8b672e4e --- /dev/null +++ b/internal/nbconn/nbconn_test.go @@ -0,0 +1,584 @@ +package nbconn_test + +import ( + "crypto/tls" + "errors" + "io" + "net" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/internal/nbconn" + "github.com/stretchr/testify/require" +) + +// Test keys generated with: +// +// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost' + +var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE----- +MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls +b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ +BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5 +yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT +caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT +0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW +c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v +7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg +Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw +HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g +TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk +D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB +hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y +E7ZYmaKTMOhvkg== +-----END CERTIFICATE-----`) + +// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in +// source code. +var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny +k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+ +fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px +N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav +IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM +4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX +IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8 +TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL +CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ +/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn +lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I +Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9 +YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp +RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq +MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd +3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE +Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0 +TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA +riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr +IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu +nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk +WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc +Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77 +DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD +pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG +2qWm8jTPeDC3sq+67s2oojHf+Q== +-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY")) + +func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) { + for _, tt := range []struct { + name string + makeConns func(t *testing.T) (local, remote net.Conn) + useTLS bool + fakeNonBlockingIO bool + }{ + { + name: "Pipe", + makeConns: makePipeConns, + useTLS: false, + fakeNonBlockingIO: true, + }, + { + name: "TCP with Fake Non-blocking IO", + makeConns: makeTCPConns, + useTLS: false, + fakeNonBlockingIO: true, + }, + { + name: "TLS over TCP with Fake Non-blocking IO", + makeConns: makeTCPConns, + useTLS: true, + fakeNonBlockingIO: true, + }, + { + name: "TCP with Real Non-blocking IO", + makeConns: makeTCPConns, + useTLS: false, + fakeNonBlockingIO: false, + }, + { + name: "TLS over TCP with Real Non-blocking IO", + makeConns: makeTCPConns, + useTLS: true, + fakeNonBlockingIO: false, + }, + } { + t.Run(tt.name, func(t *testing.T) { + local, remote := tt.makeConns(t) + + // Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get + // garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never + // uses remote it may be garbage collected leading to the connection being closed. + defer local.Close() + defer remote.Close() + + var conn nbconn.Conn + netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO) + + if tt.useTLS { + cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey) + require.NoError(t, err) + + tlsServer := tls.Server(remote, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + serverTLSHandshakeChan := make(chan error) + go func() { + err := tlsServer.Handshake() + serverTLSHandshakeChan <- err + }() + + tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true}) + require.NoError(t, err) + conn = tlsConn + + err = <-serverTLSHandshakeChan + require.NoError(t, err) + remote = tlsServer + } else { + conn = netConn + } + + f(t, conn, remote) + }) + } +} + +// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is +// useful for testing an exact sequence of reads and writes with the underlying connection blocking. +func makePipeConns(t *testing.T) (local, remote net.Conn) { + local, remote = net.Pipe() + t.Cleanup(func() { + local.Close() + remote.Close() + }) + + return local, remote +} + +// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost. +func makeTCPConns(t *testing.T) (local, remote net.Conn) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + type acceptResultT struct { + conn net.Conn + err error + } + acceptChan := make(chan acceptResultT) + + go func() { + conn, err := ln.Accept() + acceptChan <- acceptResultT{conn: conn, err: err} + }() + + local, err = net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + + acceptResult := <-acceptChan + require.NoError(t, acceptResult.err) + + remote = acceptResult.conn + + return local, remote +} + +func TestWriteIsBuffered(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + // net.Pipe is synchronous so the Write would block if not buffered. + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 1) + go func() { + err := conn.Flush() + errChan <- err + }() + + readBuf := make([]byte, len(writeBuf)) + _, err = remote.Read(readBuf) + require.NoError(t, err) + + require.NoError(t, <-errChan) + }) +} + +func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + err := conn.SetWriteDeadline(time.Now()) + require.NoError(t, err) + + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + }) +} + +func TestReadFlushesWriteBuffer(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 2) + go func() { + readBuf := make([]byte, len(writeBuf)) + _, err := remote.Read(readBuf) + errChan <- err + + _, err = remote.Write([]byte("okay")) + errChan <- err + }() + + readBuf := make([]byte, 4) + _, err = conn.Read(readBuf) + require.NoError(t, err) + require.Equal(t, []byte("okay"), readBuf) + + require.NoError(t, <-errChan) + require.NoError(t, <-errChan) + }) +} + +func TestCloseFlushesWriteBuffer(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 1) + go func() { + readBuf := make([]byte, len(writeBuf)) + _, err := remote.Read(readBuf) + errChan <- err + }() + + err = conn.Close() + require.NoError(t, err) + + require.NoError(t, <-errChan) + }) +} + +// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with +// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing +// large values. +func TestInternalNonBlockingWrite(t *testing.T) { + const deadlockSize = 4 * 1024 * 1024 + + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := make([]byte, deadlockSize) + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, deadlockSize, n) + + errChan := make(chan error, 1) + go func() { + remoteWriteBuf := make([]byte, deadlockSize) + _, err := remote.Write(remoteWriteBuf) + if err != nil { + errChan <- err + return + } + + readBuf := make([]byte, deadlockSize) + _, err = io.ReadFull(remote, readBuf) + errChan <- err + }() + + readBuf := make([]byte, deadlockSize) + _, err = io.ReadFull(conn, readBuf) + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + + require.NoError(t, <-errChan) + }) +} + +func TestInternalNonBlockingWriteWithDeadline(t *testing.T) { + const deadlockSize = 4 * 1024 * 1024 + + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := make([]byte, deadlockSize) + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, deadlockSize, n) + + err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + require.NoError(t, err) + + err = conn.Flush() + require.Error(t, err) + require.Contains(t, err.Error(), "i/o timeout") + }) +} + +func TestNonBlockingRead(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + err := conn.SetReadDeadline(nbconn.NonBlockingDeadline) + require.NoError(t, err) + + buf := make([]byte, 4) + n, err := conn.Read(buf) + require.ErrorIs(t, err, nbconn.ErrWouldBlock) + require.EqualValues(t, 0, n) + + errChan := make(chan error, 1) + go func() { + _, err := remote.Write([]byte("okay")) + errChan <- err + }() + + err = conn.SetReadDeadline(time.Time{}) + require.NoError(t, err) + + n, err = conn.Read(buf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + }) +} + +func TestBufferNonBlockingRead(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + err := conn.BufferReadUntilBlock() + require.NoError(t, err) + + errChan := make(chan error, 1) + go func() { + _, err := remote.Write([]byte("okay")) + errChan <- err + }() + + for i := 0; i < 1000; i++ { + err = conn.BufferReadUntilBlock() + if !errors.Is(err, nbconn.ErrWouldBlock) { + break + } + time.Sleep(time.Millisecond) + } + require.NoError(t, err) + + buf := make([]byte, 4) + n, err := conn.Read(buf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + require.Equal(t, []byte("okay"), buf) + }) +} + +func TestReadPreviouslyBuffered(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 5) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 5, n) + require.Equal(t, []byte("alpha"), readBuf) + }) +} + +func TestReadMoreThanPreviouslyBufferedDoesNotBlock(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 10) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 5, n) + require.Equal(t, []byte("alpha"), readBuf[:n]) + }) +} + +func TestReadPreviouslyBufferedPartialRead(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 2) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 2, n) + require.Equal(t, []byte("al"), readBuf) + + readBuf = make([]byte, 3) + n, err = conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 3, n) + require.Equal(t, []byte("pha"), readBuf) + }) +} + +func TestReadMultiplePreviouslyBuffered(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + _, err = remote.Write([]byte("beta")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 9) + n, err := io.ReadFull(conn, readBuf) + require.NoError(t, err) + require.EqualValues(t, 9, n) + require.Equal(t, []byte("alphabeta"), readBuf) + }) +} + +func TestReadPreviouslyBufferedAndReadMore(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + + flushCompleteChan := make(chan struct{}) + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + <-flushCompleteChan + + _, err = remote.Write([]byte("beta")) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + close(flushCompleteChan) + + readBuf := make([]byte, 9) + + n, err := io.ReadFull(conn, readBuf) + require.NoError(t, err) + require.EqualValues(t, 9, n) + require.Equal(t, []byte("alphabeta"), readBuf) + + err = <-errChan + require.NoError(t, err) + }) +} diff --git a/internal/pgio/README.md b/internal/pgio/README.md new file mode 100644 index 00000000..b2fc5801 --- /dev/null +++ b/internal/pgio/README.md @@ -0,0 +1,6 @@ +# pgio + +Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. + +pgio provides functions for appending integers to a []byte while doing byte +order conversion. diff --git a/internal/pgio/doc.go b/internal/pgio/doc.go new file mode 100644 index 00000000..ef2dcc7f --- /dev/null +++ b/internal/pgio/doc.go @@ -0,0 +1,6 @@ +// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. +/* +pgio provides functions for appending integers to a []byte while doing byte +order conversion. +*/ +package pgio diff --git a/internal/pgio/write.go b/internal/pgio/write.go new file mode 100644 index 00000000..96aedf9d --- /dev/null +++ b/internal/pgio/write.go @@ -0,0 +1,40 @@ +package pgio + +import "encoding/binary" + +func AppendUint16(buf []byte, n uint16) []byte { + wp := len(buf) + buf = append(buf, 0, 0) + binary.BigEndian.PutUint16(buf[wp:], n) + return buf +} + +func AppendUint32(buf []byte, n uint32) []byte { + wp := len(buf) + buf = append(buf, 0, 0, 0, 0) + binary.BigEndian.PutUint32(buf[wp:], n) + return buf +} + +func AppendUint64(buf []byte, n uint64) []byte { + wp := len(buf) + buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) + binary.BigEndian.PutUint64(buf[wp:], n) + return buf +} + +func AppendInt16(buf []byte, n int16) []byte { + return AppendUint16(buf, uint16(n)) +} + +func AppendInt32(buf []byte, n int32) []byte { + return AppendUint32(buf, uint32(n)) +} + +func AppendInt64(buf []byte, n int64) []byte { + return AppendUint64(buf, uint64(n)) +} + +func SetInt32(buf []byte, n int32) { + binary.BigEndian.PutUint32(buf, uint32(n)) +} diff --git a/internal/pgio/write_test.go b/internal/pgio/write_test.go new file mode 100644 index 00000000..bd50e71c --- /dev/null +++ b/internal/pgio/write_test.go @@ -0,0 +1,78 @@ +package pgio + +import ( + "reflect" + "testing" +) + +func TestAppendUint16NilBuf(t *testing.T) { + buf := AppendUint16(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint16EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint16(buf, 1) + if !reflect.DeepEqual(buf, []byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 4) + AppendUint16(buf, 1) + buf = buf[0:2] + if !reflect.DeepEqual(buf, []byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint32NilBuf(t *testing.T) { + buf := AppendUint32(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint32EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint32(buf, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 4) + AppendUint32(buf, 1) + buf = buf[0:4] + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint64NilBuf(t *testing.T) { + buf := AppendUint64(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} + +func TestAppendUint64EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint64(buf, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} + +func TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 8) + AppendUint64(buf, 1) + buf = buf[0:8] + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} diff --git a/internal/pgmock/pgmock.go b/internal/pgmock/pgmock.go new file mode 100644 index 00000000..c82d7ffc --- /dev/null +++ b/internal/pgmock/pgmock.go @@ -0,0 +1,136 @@ +// Package pgmock provides the ability to mock a PostgreSQL server. +package pgmock + +import ( + "fmt" + "io" + "reflect" + + "github.com/jackc/pgx/v5/pgproto3" +) + +type Step interface { + Step(*pgproto3.Backend) error +} + +type Script struct { + Steps []Step +} + +func (s *Script) Run(backend *pgproto3.Backend) error { + for _, step := range s.Steps { + err := step.Step(backend) + if err != nil { + return err + } + } + + return nil +} + +func (s *Script) Step(backend *pgproto3.Backend) error { + return s.Run(backend) +} + +type expectMessageStep struct { + want pgproto3.FrontendMessage + any bool +} + +func (e *expectMessageStep) Step(backend *pgproto3.Backend) error { + msg, err := backend.Receive() + if err != nil { + return err + } + + if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) { + return nil + } + + if !reflect.DeepEqual(msg, e.want) { + return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +type expectStartupMessageStep struct { + want *pgproto3.StartupMessage + any bool +} + +func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error { + msg, err := backend.ReceiveStartupMessage() + if err != nil { + return err + } + + if e.any { + return nil + } + + if !reflect.DeepEqual(msg, e.want) { + return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +func ExpectMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, false) +} + +func ExpectAnyMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, true) +} + +func expectMessage(want pgproto3.FrontendMessage, any bool) Step { + if want, ok := want.(*pgproto3.StartupMessage); ok { + return &expectStartupMessageStep{want: want, any: any} + } + + return &expectMessageStep{want: want, any: any} +} + +type sendMessageStep struct { + msg pgproto3.BackendMessage +} + +func (e *sendMessageStep) Step(backend *pgproto3.Backend) error { + backend.Send(e.msg) + return backend.Flush() +} + +func SendMessage(msg pgproto3.BackendMessage) Step { + return &sendMessageStep{msg: msg} +} + +type waitForCloseMessageStep struct{} + +func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error { + for { + msg, err := backend.Receive() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if _, ok := msg.(*pgproto3.Terminate); ok { + return nil + } + } +} + +func WaitForClose() Step { + return &waitForCloseMessageStep{} +} + +func AcceptUnauthenticatedConnRequestSteps() []Step { + return []Step{ + ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + SendMessage(&pgproto3.AuthenticationOk{}), + SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + } +} diff --git a/internal/pgmock/pgmock_test.go b/internal/pgmock/pgmock_test.go new file mode 100644 index 00000000..bc787398 --- /dev/null +++ b/internal/pgmock/pgmock_test.go @@ -0,0 +1,93 @@ +package pgmock_test + +import ( + "context" + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/internal/pgmock" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScript(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"})) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + pgproto3.FieldDescription{ + Name: []byte("?column?"), + TableOID: 0, + TableAttributeNumber: 0, + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: -1, + Format: 0, + }, + }, + })) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.DataRow{ + Values: [][]byte{[]byte("42")}, + })) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")})) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) + script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Terminate{})) + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(conn, conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + pgConn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + results, err := pgConn.Exec(ctx, "select 42").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "42", string(results[0].Rows[0][0])) + + pgConn.Close(ctx) + + assert.NoError(t, <-serverErrChan) +} diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index a7a94e93..3b7bb41f 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -12,13 +12,13 @@ import ( // Part is either a string or an int. A string is raw SQL. An int is a // argument placeholder. -type Part interface{} +type Part any type Query struct { Parts []Part } -func (q *Query) Sanitize(args ...interface{}) (string, error) { +func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) buf := &bytes.Buffer{} @@ -295,7 +295,7 @@ func multilineCommentState(l *sqlLexer) stateFn { // SanitizeSQL replaces placeholder values with args. It quotes and escapes args // as necessary. This function is only safe when standard_conforming_strings is // on. -func SanitizeSQL(sql string, args ...interface{}) (string, error) { +func SanitizeSQL(sql string, args ...any) (string, error) { query, err := NewQuery(sql) if err != nil { return "", err diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index bbac84b0..9b5800ec 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/internal/sanitize" + "github.com/jackc/pgx/v5/internal/sanitize" ) func TestNewQuery(t *testing.T) { @@ -111,57 +111,57 @@ func TestNewQuery(t *testing.T) { func TestQuerySanitize(t *testing.T) { successfulTests := []struct { query sanitize.Query - args []interface{} + args []any expected string }{ { query: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, - args: []interface{}{}, + args: []any{}, expected: `select 42`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{int64(42)}, + args: []any{int64(42)}, expected: `select 42`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{float64(1.23)}, + args: []any{float64(1.23)}, expected: `select 1.23`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{true}, + args: []any{true}, expected: `select true`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{[]byte{0, 1, 2, 3, 255}}, + args: []any{[]byte{0, 1, 2, 3, 255}}, expected: `select '\x00010203ff'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{nil}, + args: []any{nil}, expected: `select null`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{"foobar"}, + args: []any{"foobar"}, expected: `select 'foobar'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{"foo'bar"}, + args: []any{"foo'bar"}, expected: `select 'foo''bar'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{`foo\'bar`}, + args: []any{`foo\'bar`}, expected: `select 'foo\''bar'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}}, - args: []interface{}{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, + args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, expected: `insert '2020-03-01 23:59:59.999999Z'`, }, } @@ -180,22 +180,22 @@ func TestQuerySanitize(t *testing.T) { errorTests := []struct { query sanitize.Query - args []interface{} + args []any expected string }{ { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}}, - args: []interface{}{int64(42)}, + args: []any{int64(42)}, expected: `insufficient arguments`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}}, - args: []interface{}{int64(42)}, + args: []any{int64(42)}, expected: `unused argument: 0`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{42}, + args: []any{42}, expected: `invalid arg type: int`, }, } diff --git a/internal/stmtcache/lru_cache.go b/internal/stmtcache/lru_cache.go new file mode 100644 index 00000000..a25cc8b1 --- /dev/null +++ b/internal/stmtcache/lru_cache.go @@ -0,0 +1,98 @@ +package stmtcache + +import ( + "container/list" + + "github.com/jackc/pgx/v5/pgconn" +) + +// LRUCache implements Cache with a Least Recently Used (LRU) cache. +type LRUCache struct { + cap int + m map[string]*list.Element + l *list.List + invalidStmts []*pgconn.StatementDescription +} + +// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache. +func NewLRUCache(cap int) *LRUCache { + return &LRUCache{ + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + } +} + +// Get returns the statement description for sql. Returns nil if not found. +func (c *LRUCache) Get(key string) *pgconn.StatementDescription { + if el, ok := c.m[key]; ok { + c.l.MoveToFront(el) + return el.Value.(*pgconn.StatementDescription) + } + + return nil + +} + +// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. +func (c *LRUCache) Put(sd *pgconn.StatementDescription) { + if sd.SQL == "" { + panic("cannot store statement description with empty SQL") + } + + if _, present := c.m[sd.SQL]; present { + return + } + + if c.l.Len() == c.cap { + c.invalidateOldest() + } + + el := c.l.PushFront(sd) + c.m[sd.SQL] = el +} + +// Invalidate invalidates statement description identified by sql. Does nothing if not found. +func (c *LRUCache) Invalidate(sql string) { + if el, ok := c.m[sql]; ok { + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + c.l.Remove(el) + } +} + +// InvalidateAll invalidates all statement descriptions. +func (c *LRUCache) InvalidateAll() { + el := c.l.Front() + for el != nil { + c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + el = el.Next() + } + + c.m = make(map[string]*list.Element) + c.l = list.New() +} + +func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription { + invalidStmts := c.invalidStmts + c.invalidStmts = nil + return invalidStmts +} + +// Len returns the number of cached prepared statement descriptions. +func (c *LRUCache) Len() int { + return c.l.Len() +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *LRUCache) Cap() int { + return c.cap +} + +func (c *LRUCache) invalidateOldest() { + oldest := c.l.Back() + sd := oldest.Value.(*pgconn.StatementDescription) + c.invalidStmts = append(c.invalidStmts, sd) + delete(c.m, sd.SQL) + c.l.Remove(oldest) +} diff --git a/internal/stmtcache/stmtcache.go b/internal/stmtcache/stmtcache.go new file mode 100644 index 00000000..e1bdcba5 --- /dev/null +++ b/internal/stmtcache/stmtcache.go @@ -0,0 +1,57 @@ +// Package stmtcache is a cache for statement descriptions. +package stmtcache + +import ( + "strconv" + "sync/atomic" + + "github.com/jackc/pgx/v5/pgconn" +) + +var stmtCounter int64 + +// NextStatementName returns a statement name that will be unique for the lifetime of the program. +func NextStatementName() string { + n := atomic.AddInt64(&stmtCounter, 1) + return "stmtcache_" + strconv.FormatInt(n, 10) +} + +// Cache caches statement descriptions. +type Cache interface { + // Get returns the statement description for sql. Returns nil if not found. + Get(sql string) *pgconn.StatementDescription + + // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. + Put(sd *pgconn.StatementDescription) + + // Invalidate invalidates statement description identified by sql. Does nothing if not found. + Invalidate(sql string) + + // InvalidateAll invalidates all statement descriptions. + InvalidateAll() + + // HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. + HandleInvalidated() []*pgconn.StatementDescription + + // Len returns the number of cached prepared statement descriptions. + Len() int + + // Cap returns the maximum number of cached prepared statement descriptions. + Cap() int +} + +func IsStatementInvalid(err error) bool { + pgErr, ok := err.(*pgconn.PgError) + if !ok { + return false + } + + // https://github.com/jackc/pgx/issues/1162 + // + // We used to look for the message "cached plan must not change result type". However, that message can be localized. + // Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to + // tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't + // have so it should be safe. + possibleInvalidCachedPlanError := pgErr.Code == "0A000" + return possibleInvalidCachedPlanError +} diff --git a/internal/stmtcache/unlimited_cache.go b/internal/stmtcache/unlimited_cache.go new file mode 100644 index 00000000..f5f59396 --- /dev/null +++ b/internal/stmtcache/unlimited_cache.go @@ -0,0 +1,71 @@ +package stmtcache + +import ( + "math" + + "github.com/jackc/pgx/v5/pgconn" +) + +// UnlimitedCache implements Cache with no capacity limit. +type UnlimitedCache struct { + m map[string]*pgconn.StatementDescription + invalidStmts []*pgconn.StatementDescription +} + +// NewUnlimitedCache creates a new UnlimitedCache. +func NewUnlimitedCache() *UnlimitedCache { + return &UnlimitedCache{ + m: make(map[string]*pgconn.StatementDescription), + } +} + +// Get returns the statement description for sql. Returns nil if not found. +func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription { + return c.m[sql] +} + +// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. +func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) { + if sd.SQL == "" { + panic("cannot store statement description with empty SQL") + } + + if _, present := c.m[sd.SQL]; present { + return + } + + c.m[sd.SQL] = sd +} + +// Invalidate invalidates statement description identified by sql. Does nothing if not found. +func (c *UnlimitedCache) Invalidate(sql string) { + if sd, ok := c.m[sql]; ok { + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, sd) + } +} + +// InvalidateAll invalidates all statement descriptions. +func (c *UnlimitedCache) InvalidateAll() { + for _, sd := range c.m { + c.invalidStmts = append(c.invalidStmts, sd) + } + + c.m = make(map[string]*pgconn.StatementDescription) +} + +func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription { + invalidStmts := c.invalidStmts + c.invalidStmts = nil + return invalidStmts +} + +// Len returns the number of cached prepared statement descriptions. +func (c *UnlimitedCache) Len() int { + return len(c.m) +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *UnlimitedCache) Cap() int { + return math.MaxInt +} diff --git a/large_objects_test.go b/large_objects_test.go index 672729ee..626809e7 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -7,8 +7,9 @@ import ( "testing" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" ) func TestLargeObjects(t *testing.T) { @@ -22,7 +23,7 @@ func TestLargeObjects(t *testing.T) { t.Fatal(err) } - skipCockroachDB(t, conn, "Server does support large objects") + pgxtest.SkipCockroachDB(t, conn, "Server does support large objects") tx, err := conn.Begin(ctx) if err != nil { @@ -32,7 +33,7 @@ func TestLargeObjects(t *testing.T) { testLargeObjects(t, ctx, tx) } -func TestLargeObjectsPreferSimpleProtocol(t *testing.T) { +func TestLargeObjectsSimpleProtocol(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -43,14 +44,14 @@ func TestLargeObjectsPreferSimpleProtocol(t *testing.T) { t.Fatal(err) } - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol conn, err := pgx.ConnectConfig(ctx, config) if err != nil { t.Fatal(err) } - skipCockroachDB(t, conn, "Server does support large objects") + pgxtest.SkipCockroachDB(t, conn, "Server does support large objects") tx, err := conn.Begin(ctx) if err != nil { @@ -169,7 +170,7 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) { t.Fatal(err) } - skipCockroachDB(t, conn, "Server does support large objects") + pgxtest.SkipCockroachDB(t, conn, "Server does support large objects") tx, err := conn.Begin(ctx) if err != nil { diff --git a/log/kitlogadapter/adapter.go b/log/kitlogadapter/adapter.go deleted file mode 100644 index 0a46197f..00000000 --- a/log/kitlogadapter/adapter.go +++ /dev/null @@ -1,39 +0,0 @@ -package kitlogadapter - -import ( - "context" - - "github.com/go-kit/log" - kitlevel "github.com/go-kit/log/level" - "github.com/jackc/pgx/v4" -) - -type Logger struct { - l log.Logger -} - -func NewLogger(l log.Logger) *Logger { - return &Logger{l: l} -} - -func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - logger := l.l - for k, v := range data { - logger = log.With(logger, k, v) - } - - switch level { - case pgx.LogLevelTrace: - logger.Log("PGX_LOG_LEVEL", level, "msg", msg) - case pgx.LogLevelDebug: - kitlevel.Debug(logger).Log("msg", msg) - case pgx.LogLevelInfo: - kitlevel.Info(logger).Log("msg", msg) - case pgx.LogLevelWarn: - kitlevel.Warn(logger).Log("msg", msg) - case pgx.LogLevelError: - kitlevel.Error(logger).Log("msg", msg) - default: - logger.Log("INVALID_PGX_LOG_LEVEL", level, "error", msg) - } -} diff --git a/log/log15adapter/adapter.go b/log/log15adapter/adapter.go deleted file mode 100644 index 70608e33..00000000 --- a/log/log15adapter/adapter.go +++ /dev/null @@ -1,49 +0,0 @@ -// Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger -// log. -package log15adapter - -import ( - "context" - - "github.com/jackc/pgx/v4" -) - -// Log15Logger interface defines the subset of -// github.com/inconshreveable/log15.Logger that this adapter uses. -type Log15Logger interface { - Debug(msg string, ctx ...interface{}) - Info(msg string, ctx ...interface{}) - Warn(msg string, ctx ...interface{}) - Error(msg string, ctx ...interface{}) - Crit(msg string, ctx ...interface{}) -} - -type Logger struct { - l Log15Logger -} - -func NewLogger(l Log15Logger) *Logger { - return &Logger{l: l} -} - -func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - logArgs := make([]interface{}, 0, len(data)) - for k, v := range data { - logArgs = append(logArgs, k, v) - } - - switch level { - case pgx.LogLevelTrace: - l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...) - case pgx.LogLevelDebug: - l.l.Debug(msg, logArgs...) - case pgx.LogLevelInfo: - l.l.Info(msg, logArgs...) - case pgx.LogLevelWarn: - l.l.Warn(msg, logArgs...) - case pgx.LogLevelError: - l.l.Error(msg, logArgs...) - default: - l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...) - } -} diff --git a/log/logrusadapter/adapter.go b/log/logrusadapter/adapter.go deleted file mode 100644 index e0cd6328..00000000 --- a/log/logrusadapter/adapter.go +++ /dev/null @@ -1,42 +0,0 @@ -// Package logrusadapter provides a logger that writes to a github.com/sirupsen/logrus.Logger -// log. -package logrusadapter - -import ( - "context" - - "github.com/jackc/pgx/v4" - "github.com/sirupsen/logrus" -) - -type Logger struct { - l logrus.FieldLogger -} - -func NewLogger(l logrus.FieldLogger) *Logger { - return &Logger{l: l} -} - -func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - var logger logrus.FieldLogger - if data != nil { - logger = l.l.WithFields(data) - } else { - logger = l.l - } - - switch level { - case pgx.LogLevelTrace: - logger.WithField("PGX_LOG_LEVEL", level).Debug(msg) - case pgx.LogLevelDebug: - logger.Debug(msg) - case pgx.LogLevelInfo: - logger.Info(msg) - case pgx.LogLevelWarn: - logger.Warn(msg) - case pgx.LogLevelError: - logger.Error(msg) - default: - logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg) - } -} diff --git a/log/testingadapter/adapter.go b/log/testingadapter/adapter.go index 3ddce5a1..c901a6a6 100644 --- a/log/testingadapter/adapter.go +++ b/log/testingadapter/adapter.go @@ -6,13 +6,13 @@ import ( "context" "fmt" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5/tracelog" ) // TestingLogger interface defines the subset of testing.TB methods used by this // adapter. type TestingLogger interface { - Log(args ...interface{}) + Log(args ...any) } type Logger struct { @@ -23,8 +23,8 @@ func NewLogger(l TestingLogger) *Logger { return &Logger{l: l} } -func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - logArgs := make([]interface{}, 0, 2+len(data)) +func (l *Logger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { + logArgs := make([]any, 0, 2+len(data)) logArgs = append(logArgs, level, msg) for k, v := range data { logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v)) diff --git a/log/zapadapter/adapter.go b/log/zapadapter/adapter.go deleted file mode 100644 index ebc540aa..00000000 --- a/log/zapadapter/adapter.go +++ /dev/null @@ -1,42 +0,0 @@ -// Package zapadapter provides a logger that writes to a go.uber.org/zap.Logger. -package zapadapter - -import ( - "context" - - "github.com/jackc/pgx/v4" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -type Logger struct { - logger *zap.Logger -} - -func NewLogger(logger *zap.Logger) *Logger { - return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))} -} - -func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - fields := make([]zapcore.Field, len(data)) - i := 0 - for k, v := range data { - fields[i] = zap.Any(k, v) - i++ - } - - switch level { - case pgx.LogLevelTrace: - pl.logger.Debug(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...) - case pgx.LogLevelDebug: - pl.logger.Debug(msg, fields...) - case pgx.LogLevelInfo: - pl.logger.Info(msg, fields...) - case pgx.LogLevelWarn: - pl.logger.Warn(msg, fields...) - case pgx.LogLevelError: - pl.logger.Error(msg, fields...) - default: - pl.logger.Error(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...) - } -} diff --git a/log/zerologadapter/adapter.go b/log/zerologadapter/adapter.go deleted file mode 100644 index 6e8b4b94..00000000 --- a/log/zerologadapter/adapter.go +++ /dev/null @@ -1,102 +0,0 @@ -// Package zerologadapter provides a logger that writes to a github.com/rs/zerolog. -package zerologadapter - -import ( - "context" - - "github.com/jackc/pgx/v4" - "github.com/rs/zerolog" -) - -type Logger struct { - logger zerolog.Logger - withFunc func(context.Context, zerolog.Context) zerolog.Context - fromContext bool - skipModule bool -} - -// option options for configuring the logger when creating a new logger. -type option func(logger *Logger) - -// WithContextFunc adds possibility to get request scoped values from the -// ctx.Context before logging lines. -func WithContextFunc(withFunc func(context.Context, zerolog.Context) zerolog.Context) option { - return func(logger *Logger) { - logger.withFunc = withFunc - } -} - -// WithoutPGXModule disables adding module:pgx to the default logger context. -func WithoutPGXModule() option { - return func(logger *Logger) { - logger.skipModule = true - } -} - -// NewLogger accepts a zerolog.Logger as input and returns a new custom pgx -// logging facade as output. -func NewLogger(logger zerolog.Logger, options ...option) *Logger { - l := Logger{ - logger: logger, - } - l.init(options) - return &l -} - -// NewContextLogger creates logger that extracts the zerolog.Logger from the -// context.Context by using `zerolog.Ctx`. The zerolog.DefaultContextLogger will -// be used if no logger is associated with the context. -func NewContextLogger(options ...option) *Logger { - l := Logger{ - fromContext: true, - } - l.init(options) - return &l -} - -func (pl *Logger) init(options []option) { - for _, opt := range options { - opt(pl) - } - if !pl.skipModule { - pl.logger = pl.logger.With().Str("module", "pgx").Logger() - } -} - -func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - var zlevel zerolog.Level - switch level { - case pgx.LogLevelNone: - zlevel = zerolog.NoLevel - case pgx.LogLevelError: - zlevel = zerolog.ErrorLevel - case pgx.LogLevelWarn: - zlevel = zerolog.WarnLevel - case pgx.LogLevelInfo: - zlevel = zerolog.InfoLevel - case pgx.LogLevelDebug: - zlevel = zerolog.DebugLevel - default: - zlevel = zerolog.DebugLevel - } - - var zctx zerolog.Context - if pl.fromContext { - logger := zerolog.Ctx(ctx) - zctx = logger.With() - } else { - zctx = pl.logger.With() - } - if pl.withFunc != nil { - zctx = pl.withFunc(ctx, zctx) - } - - pgxlog := zctx.Logger() - event := pgxlog.WithLevel(zlevel) - if event.Enabled() { - if pl.fromContext && !pl.skipModule { - event.Str("module", "pgx") - } - event.Fields(data).Msg(msg) - } -} diff --git a/log/zerologadapter/adapter_test.go b/log/zerologadapter/adapter_test.go deleted file mode 100644 index 3a11cbc0..00000000 --- a/log/zerologadapter/adapter_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package zerologadapter_test - -import ( - "bytes" - "context" - "testing" - - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/log/zerologadapter" - "github.com/rs/zerolog" -) - -func TestLogger(t *testing.T) { - - t.Run("default", func(t *testing.T) { - var buf bytes.Buffer - zlogger := zerolog.New(&buf) - logger := zerologadapter.NewLogger(zlogger) - logger.Log(context.Background(), pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"}) - const want = `{"level":"info","module":"pgx","one":"two","message":"hello"} -` - got := buf.String() - if got != want { - t.Errorf("%s != %s", got, want) - } - }) - - t.Run("disable pgx module", func(t *testing.T) { - var buf bytes.Buffer - zlogger := zerolog.New(&buf) - logger := zerologadapter.NewLogger(zlogger, zerologadapter.WithoutPGXModule()) - logger.Log(context.Background(), pgx.LogLevelInfo, "hello", nil) - const want = `{"level":"info","message":"hello"} -` - got := buf.String() - if got != want { - t.Errorf("%s != %s", got, want) - } - }) - - t.Run("from context", func(t *testing.T) { - var buf bytes.Buffer - zlogger := zerolog.New(&buf) - ctx := zlogger.WithContext(context.Background()) - logger := zerologadapter.NewContextLogger() - logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"}) - const want = `{"level":"info","module":"pgx","one":"two","message":"hello"} -` - - got := buf.String() - if got != want { - t.Log(got) - t.Log(want) - t.Errorf("%s != %s", got, want) - } - }) - - var buf bytes.Buffer - type key string - var ck key - zlogger := zerolog.New(&buf) - logger := zerologadapter.NewLogger(zlogger, - zerologadapter.WithContextFunc(func(ctx context.Context, logWith zerolog.Context) zerolog.Context { - // You can use zerolog.hlog.IDFromCtx(ctx) or even - // zerolog.log.Ctx(ctx) to fetch the whole logger instance from the - // context if you want. - id, ok := ctx.Value(ck).(string) - if ok { - logWith = logWith.Str("req_id", id) - } - return logWith - }), - ) - - t.Run("no request id", func(t *testing.T) { - buf.Reset() - ctx := context.Background() - logger.Log(ctx, pgx.LogLevelInfo, "hello", nil) - const want = `{"level":"info","module":"pgx","message":"hello"} -` - got := buf.String() - if got != want { - t.Errorf("%s != %s", got, want) - } - }) - - t.Run("with request id", func(t *testing.T) { - buf.Reset() - ctx := context.WithValue(context.Background(), ck, "1") - logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"two": "2"}) - const want = `{"level":"info","module":"pgx","req_id":"1","two":"2","message":"hello"} -` - got := buf.String() - if got != want { - t.Errorf("%s != %s", got, want) - } - }) -} diff --git a/logger.go b/logger.go deleted file mode 100644 index 41f8b7e8..00000000 --- a/logger.go +++ /dev/null @@ -1,107 +0,0 @@ -package pgx - -import ( - "context" - "encoding/hex" - "errors" - "fmt" -) - -// The values for log levels are chosen such that the zero value means that no -// log level was specified. -const ( - LogLevelTrace = 6 - LogLevelDebug = 5 - LogLevelInfo = 4 - LogLevelWarn = 3 - LogLevelError = 2 - LogLevelNone = 1 -) - -// LogLevel represents the pgx logging level. See LogLevel* constants for -// possible values. -type LogLevel int - -func (ll LogLevel) String() string { - switch ll { - case LogLevelTrace: - return "trace" - case LogLevelDebug: - return "debug" - case LogLevelInfo: - return "info" - case LogLevelWarn: - return "warn" - case LogLevelError: - return "error" - case LogLevelNone: - return "none" - default: - return fmt.Sprintf("invalid level %d", ll) - } -} - -// Logger is the interface used to get logging from pgx internals. -type Logger interface { - // Log a message at the given level with data key/value pairs. data may be nil. - Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) -} - -// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface -type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) - -// Log delegates the logging request to the wrapped function -func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { - f(ctx, level, msg, data) -} - -// LogLevelFromString converts log level string to constant -// -// Valid levels: -// -// trace -// debug -// info -// warn -// error -// none -func LogLevelFromString(s string) (LogLevel, error) { - switch s { - case "trace": - return LogLevelTrace, nil - case "debug": - return LogLevelDebug, nil - case "info": - return LogLevelInfo, nil - case "warn": - return LogLevelWarn, nil - case "error": - return LogLevelError, nil - case "none": - return LogLevelNone, nil - default: - return 0, errors.New("invalid log level") - } -} - -func logQueryArgs(args []interface{}) []interface{} { - logArgs := make([]interface{}, 0, len(args)) - - for _, a := range args { - switch v := a.(type) { - case []byte: - if len(v) < 64 { - a = hex.EncodeToString(v) - } else { - a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64) - } - case string: - if len(v) > 64 { - a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64) - } - } - logArgs = append(logArgs, a) - } - - return logArgs -} diff --git a/messages.go b/messages.go deleted file mode 100644 index 5324cbb5..00000000 --- a/messages.go +++ /dev/null @@ -1,23 +0,0 @@ -package pgx - -import ( - "database/sql/driver" - - "github.com/jackc/pgtype" -) - -func convertDriverValuers(args []interface{}) ([]interface{}, error) { - for i, arg := range args { - switch arg := arg.(type) { - case pgtype.BinaryEncoder: - case pgtype.TextEncoder: - case driver.Valuer: - v, err := callValuerValue(arg) - if err != nil { - return nil, err - } - args[i] = v - } - } - return args, nil -} diff --git a/named_args.go b/named_args.go new file mode 100644 index 00000000..3d91367b --- /dev/null +++ b/named_args.go @@ -0,0 +1,266 @@ +package pgx + +import ( + "context" + "strconv" + "strings" + "unicode/utf8" +) + +// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$' +// ordinal placeholder and construct the appropriate arguments. +// +// For example, the following two queries are equivalent: +// +// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})) +// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2})) +type NamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) { + l := &sqlLexer{ + src: sql, + stateFn: rawState, + nameToOrdinal: make(map[namedArg]int, len(na)), + } + + for l.stateFn != nil { + l.stateFn = l.stateFn(l) + } + + sb := strings.Builder{} + for _, p := range l.parts { + switch p := p.(type) { + case string: + sb.WriteString(p) + case namedArg: + sb.WriteRune('$') + sb.WriteString(strconv.Itoa(l.nameToOrdinal[p])) + } + } + + newArgs = make([]any, len(l.nameToOrdinal)) + for name, ordinal := range l.nameToOrdinal { + newArgs[ordinal-1] = na[string(name)] + } + + return sb.String(), newArgs +} + +type namedArg string + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []any + + nameToOrdinal map[namedArg]int +} + +type stateFn func(*sqlLexer) stateFn + +func rawState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case 'e', 'E': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '\'' { + l.pos += width + return escapeStringState + } + case '\'': + return singleQuoteState + case '"': + return doubleQuoteState + case '@': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + if isLetter(nextRune) { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return namedArgState + } + case '-': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '-' { + l.pos += width + return oneLineCommentState + } + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + return multilineCommentState + } + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func isLetter(r rune) bool { + return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') +} + +func namedArgState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + if r == utf8.RuneError { + if l.pos-l.start > 0 { + na := namedArg(l.src[l.start:l.pos]) + if _, found := l.nameToOrdinal[na]; !found { + l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 + } + l.parts = append(l.parts, na) + l.start = l.pos + } + return nil + } else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') { + l.pos -= width + na := namedArg(l.src[l.start:l.pos]) + if _, found := l.nameToOrdinal[na]; !found { + l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 + } + l.parts = append(l.parts, namedArg(na)) + l.start = l.pos + return rawState + } + } +} + +func singleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func doubleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '"': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '"' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func escapeStringState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func oneLineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\n', '\r': + return rawState + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func multilineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + l.nested++ + } + case '*': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '/' { + continue + } + + l.pos += width + if l.nested == 0 { + return rawState + } + l.nested-- + + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} diff --git a/named_args_test.go b/named_args_test.go new file mode 100644 index 00000000..116e03dc --- /dev/null +++ b/named_args_test.go @@ -0,0 +1,102 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/assert" +) + +func TestNamedArgsRewriteQuery(t *testing.T) { + t.Parallel() + + for i, tt := range []struct { + sql string + args []any + namedArgs pgx.NamedArgs + expectedSQL string + expectedArgs []any + }{ + { + sql: "select * from users where id = @id", + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: "select * from users where id = $1", + expectedArgs: []any{int32(42)}, + }, + { + sql: "select * from t where foo < @abc and baz = @def and bar < @abc", + namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)}, + expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1", + expectedArgs: []any{int32(42), int32(1)}, + }, + { + sql: "select @a::int, @b::text", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "select $1::int, $2::text", + expectedArgs: []any{int32(42), "foo"}, + }, + { + sql: "select @Abc::int, @b_4::text", + namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo"}, + expectedSQL: "select $1::int, $2::text", + expectedArgs: []any{int32(42), "foo"}, + }, + { + sql: "at end @", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "at end @", + expectedArgs: []any{}, + }, + { + sql: "ignores without letter after @ foo bar", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "ignores without letter after @ foo bar", + expectedArgs: []any{}, + }, + { + sql: "name must start with letter @1 foo bar", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "name must start with letter @1 foo bar", + expectedArgs: []any{}, + }, + { + sql: `select *, '@foo' as "@bar" from users where id = @id`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`, + expectedArgs: []any{int32(42)}, + }, + { + sql: `select * -- @foo + from users -- @single line comments + where id = @id;`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select * -- @foo + from users -- @single line comments + where id = $1;`, + expectedArgs: []any{int32(42)}, + }, + { + sql: `select * /* @multi line + @comment + */ + /* /* with @nesting */ */ + from users + where id = @id;`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select * /* @multi line + @comment + */ + /* /* with @nesting */ */ + from users + where id = $1;`, + expectedArgs: []any{int32(42)}, + }, + + // test comments and quotes + } { + sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args) + assert.Equalf(t, tt.expectedSQL, sql, "%d", i) + assert.Equalf(t, tt.expectedArgs, args, "%d", i) + } +} diff --git a/pgbouncer_test.go b/pgbouncer_test.go index e3fa4d0c..ac22b679 100644 --- a/pgbouncer_test.go +++ b/pgbouncer_test.go @@ -5,9 +5,7 @@ import ( "os" "testing" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,9 +17,8 @@ func TestPgbouncerStatementCacheDescribe(t *testing.T) { } config := mustParseConfig(t, connString) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 1024) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.DescriptionCacheCapacity = 1024 testPgbouncer(t, config, 10, 100) } @@ -33,8 +30,7 @@ func TestPgbouncerSimpleProtocol(t *testing.T) { } config := mustParseConfig(t, connString) - config.BuildStatementCache = nil - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol testPgbouncer(t, config, 10, 100) } diff --git a/pgconn/README.md b/pgconn/README.md new file mode 100644 index 00000000..4f0349f2 --- /dev/null +++ b/pgconn/README.md @@ -0,0 +1,53 @@ +# pgconn + +Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq. +It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx. +Applications should handle normal queries with a higher level library and only use pgconn directly when required for +low-level access to PostgreSQL functionality. + +## Example Usage + +```go +pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) +if err != nil { + log.Fatalln("pgconn failed to connect:", err) +} +defer pgConn.Close(context.Background()) + +result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +for result.NextRow() { + fmt.Println("User 123 has email:", string(result.Values()[0])) +} +_, err = result.Close() +if err != nil { + log.Fatalln("failed reading result:", err) +} +``` + +## Testing + +The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING` +environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*` +environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify +environment variable handling. + +### Example Test Environment + +Connect to your PostgreSQL server and run: + +``` +create database pgx_test; +``` + +Now you can run the tests: + +```bash +PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./... +``` + +### Connection and Authentication Tests + +Pgconn supports multiple connection types and means of authentication. These tests are optional. They +will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being +skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change +authentication code. diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go new file mode 100644 index 00000000..6ca9e337 --- /dev/null +++ b/pgconn/auth_scram.go @@ -0,0 +1,272 @@ +// SCRAM-SHA-256 authentication +// +// Resources: +// https://tools.ietf.org/html/rfc5802 +// https://tools.ietf.org/html/rfc8265 +// https://www.postgresql.org/docs/current/sasl-authentication.html +// +// Inspiration drawn from other implementations: +// https://github.com/lib/pq/pull/608 +// https://github.com/lib/pq/pull/788 +// https://github.com/lib/pq/pull/833 + +package pgconn + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "strconv" + + "github.com/jackc/pgx/v5/pgproto3" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/text/secure/precis" +) + +const clientNonceLen = 18 + +// Perform SCRAM authentication. +func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { + sc, err := newScramClient(serverAuthMechanisms, c.config.Password) + if err != nil { + return err + } + + // Send client-first-message in a SASLInitialResponse + saslInitialResponse := &pgproto3.SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: sc.clientFirstMessage(), + } + c.frontend.Send(saslInitialResponse) + err = c.frontend.Flush() + if err != nil { + return err + } + + // Receive server-first-message payload in a AuthenticationSASLContinue. + saslContinue, err := c.rxSASLContinue() + if err != nil { + return err + } + err = sc.recvServerFirstMessage(saslContinue.Data) + if err != nil { + return err + } + + // Send client-final-message in a SASLResponse + saslResponse := &pgproto3.SASLResponse{ + Data: []byte(sc.clientFinalMessage()), + } + c.frontend.Send(saslResponse) + err = c.frontend.Flush() + if err != nil { + return err + } + + // Receive server-final-message payload in a AuthenticationSASLFinal. + saslFinal, err := c.rxSASLFinal() + if err != nil { + return err + } + return sc.recvServerFinalMessage(saslFinal.Data) +} + +func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + switch m := msg.(type) { + case *pgproto3.AuthenticationSASLContinue: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) + } + + return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg) +} + +func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + switch m := msg.(type) { + case *pgproto3.AuthenticationSASLFinal: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) + } + + return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg) +} + +type scramClient struct { + serverAuthMechanisms []string + password []byte + clientNonce []byte + + clientFirstMessageBare []byte + + serverFirstMessage []byte + clientAndServerNonce []byte + salt []byte + iterations int + + saltedPassword []byte + authMessage []byte +} + +func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { + sc := &scramClient{ + serverAuthMechanisms: serverAuthMechanisms, + } + + // Ensure server supports SCRAM-SHA-256 + hasScramSHA256 := false + for _, mech := range sc.serverAuthMechanisms { + if mech == "SCRAM-SHA-256" { + hasScramSHA256 = true + break + } + } + if !hasScramSHA256 { + return nil, errors.New("server does not support SCRAM-SHA-256") + } + + // precis.OpaqueString is equivalent to SASLprep for password. + var err error + sc.password, err = precis.OpaqueString.Bytes([]byte(password)) + if err != nil { + // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. + sc.password = []byte(password) + } + + buf := make([]byte, clientNonceLen) + _, err = rand.Read(buf) + if err != nil { + return nil, err + } + sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf))) + base64.RawStdEncoding.Encode(sc.clientNonce, buf) + + return sc, nil +} + +func (sc *scramClient) clientFirstMessage() []byte { + sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) + return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) +} + +func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { + sc.serverFirstMessage = serverFirstMessage + buf := serverFirstMessage + if !bytes.HasPrefix(buf, []byte("r=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include r=") + } + buf = buf[2:] + + idx := bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + sc.clientAndServerNonce = buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("s=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + buf = buf[2:] + + idx = bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + saltStr := buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("i=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + buf = buf[2:] + iterationsStr := buf + + var err error + sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) + if err != nil { + return fmt.Errorf("invalid SCRAM salt received from server: %w", err) + } + + sc.iterations, err = strconv.Atoi(string(iterationsStr)) + if err != nil || sc.iterations <= 0 { + return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err) + } + + if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not start with client nonce") + } + + if len(sc.clientAndServerNonce) <= len(sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not include server nonce") + } + + return nil +} + +func (sc *scramClient) clientFinalMessage() string { + clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) + + sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) + sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) + + clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) + + return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) +} + +func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { + if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) { + return errors.New("invalid SCRAM server-final-message received from server") + } + + serverSignature := serverFinalMessage[2:] + + if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { + return errors.New("invalid SCRAM ServerSignature received from server") + } + + return nil +} + +func computeHMAC(key, msg []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(msg) + return mac.Sum(nil) +} + +func computeClientProof(saltedPassword, authMessage []byte) []byte { + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + storedKey := sha256.Sum256(clientKey) + clientSignature := computeHMAC(storedKey[:], authMessage) + + clientProof := make([]byte, len(clientSignature)) + for i := 0; i < len(clientSignature); i++ { + clientProof[i] = clientKey[i] ^ clientSignature[i] + } + + buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) + base64.StdEncoding.Encode(buf, clientProof) + return buf +} + +func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + serverSignature := computeHMAC(serverKey, authMessage) + buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) + base64.StdEncoding.Encode(buf, serverSignature) + return buf +} diff --git a/pgconn/benchmark_private_test.go b/pgconn/benchmark_private_test.go new file mode 100644 index 00000000..9ea036ec --- /dev/null +++ b/pgconn/benchmark_private_test.go @@ -0,0 +1,73 @@ +package pgconn + +import ( + "strings" + "testing" +) + +func BenchmarkCommandTagRowsAffected(b *testing.B) { + benchmarks := []struct { + commandTag string + rowsAffected int64 + }{ + {"UPDATE 1", 1}, + {"UPDATE 123456789", 123456789}, + {"INSERT 0 1", 1}, + {"INSERT 0 123456789", 123456789}, + } + + for _, bm := range benchmarks { + ct := CommandTag{s: bm.commandTag} + b.Run(bm.commandTag, func(b *testing.B) { + var n int64 + for i := 0; i < b.N; i++ { + n = ct.RowsAffected() + } + if n != bm.rowsAffected { + b.Errorf("expected %d got %d", bm.rowsAffected, n) + } + }) + } +} + +func BenchmarkCommandTagTypeFromString(b *testing.B) { + ct := CommandTag{s: "UPDATE 1"} + + var update bool + for i := 0; i < b.N; i++ { + update = strings.HasPrefix(ct.String(), "UPDATE") + } + if !update { + b.Error("expected update") + } +} + +func BenchmarkCommandTagInsert(b *testing.B) { + benchmarks := []struct { + commandTag string + is bool + }{ + {"INSERT 1", true}, + {"INSERT 1234567890", true}, + {"UPDATE 1", false}, + {"UPDATE 1234567890", false}, + {"DELETE 1", false}, + {"DELETE 1234567890", false}, + {"SELECT 1", false}, + {"SELECT 1234567890", false}, + {"UNKNOWN 1234567890", false}, + } + + for _, bm := range benchmarks { + ct := CommandTag{s: bm.commandTag} + b.Run(bm.commandTag, func(b *testing.B) { + var is bool + for i := 0; i < b.N; i++ { + is = ct.Insert() + } + if is != bm.is { + b.Errorf("expected %v got %v", bm.is, is) + } + }) + } +} diff --git a/pgconn/benchmark_test.go b/pgconn/benchmark_test.go new file mode 100644 index 00000000..ffa42243 --- /dev/null +++ b/pgconn/benchmark_test.go @@ -0,0 +1,254 @@ +package pgconn_test + +import ( + "bytes" + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" +) + +func BenchmarkConnect(b *testing.B) { + benchmarks := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + connString := os.Getenv(bm.env) + if connString == "" { + b.Skipf("Skipping due to missing environment variable %v", bm.env) + } + + for i := 0; i < b.N; i++ { + conn, err := pgconn.Connect(context.Background(), connString) + require.Nil(b, err) + + err = conn.Close(context.Background()) + require.Nil(b, err) + } + }) + } +} + +func BenchmarkExec(b *testing.B) { + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + benchmarks := []struct { + name string + ctx context.Context + }{ + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, + {"empty context", context.TODO()}, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + + err := mrr.Close() + if err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkExecPossibleToCancel(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) + + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + + b.ResetTimer() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + + err := mrr.Close() + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkExecPrepared(b *testing.B) { + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + + benchmarks := []struct { + name string + ctx context.Context + }{ + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, + {"empty context", context.TODO()}, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) + + _, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + }) + } +} + +func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) + + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount += 1 + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } +} + +// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) { +// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) +// require.Nil(b, err) +// defer closeConn(b, conn) + +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() + +// b.ResetTimer() + +// for i := 0; i < b.N; i++ { +// conn.ChanToSetDeadline().Watch(ctx) +// conn.ChanToSetDeadline().Ignore() +// } +// } diff --git a/pgconn/config.go b/pgconn/config.go new file mode 100644 index 00000000..cff4bed0 --- /dev/null +++ b/pgconn/config.go @@ -0,0 +1,886 @@ +package pgconn + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io" + "io/ioutil" + "math" + "net" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/jackc/pgpassfile" + "github.com/jackc/pgservicefile" + "github.com/jackc/pgx/v5/pgproto3" +) + +type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error +type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error +type GetSSLPasswordFunc func(ctx context.Context) string + +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A +// manually initialized Config will cause ConnectConfig to panic. +type Config struct { + Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + ConnectTimeout time.Duration + DialFunc DialFunc // e.g. net.Dialer.DialContext + LookupFunc LookupFunc // e.g. net.Resolver.LookupHost + BuildFrontend BuildFrontendFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + KerberosSrvName string + KerberosSpn string + Fallbacks []*FallbackConfig + + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. + // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next + // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. + ValidateConnect ValidateConnectFunc + + // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables + // or prepare statements). If this returns an error the connection attempt fails. + AfterConnect AfterConnectFunc + + // OnNotice is a callback function called when a notice response is received. + OnNotice NoticeHandler + + // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. + OnNotification NotificationHandler + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. +} + +// ParseConfigOptions contains options that control how a config is built such as getsslpassword. +type ParseConfigOptions struct { + // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function + // PQsetSSLKeyPassHook_OpenSSL. + GetSSLPassword GetSSLPasswordFunc +} + +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the TLSConfig field: +// according to the tls.Config docs it must not be modified after creation. +func (c *Config) Copy() *Config { + newConf := new(Config) + *newConf = *c + if newConf.TLSConfig != nil { + newConf.TLSConfig = c.TLSConfig.Clone() + } + if newConf.RuntimeParams != nil { + newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) + for k, v := range c.RuntimeParams { + newConf.RuntimeParams[k] = v + } + } + if newConf.Fallbacks != nil { + newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) + for i, fallback := range c.Fallbacks { + newFallback := new(FallbackConfig) + *newFallback = *fallback + if newFallback.TLSConfig != nil { + newFallback.TLSConfig = fallback.TLSConfig.Clone() + } + newConf.Fallbacks[i] = newFallback + } + } + return newConf +} + +// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a +// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. +type FallbackConfig struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + TLSConfig *tls.Config // nil disables TLS +} + +// isAbsolutePath checks if the provided value is an absolute path either +// beginning with a forward slash (as on Linux-based systems) or with a capital +// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). +func isAbsolutePath(path string) bool { + isWindowsPath := func(p string) bool { + if len(p) < 3 { + return false + } + drive := p[0] + colon := p[1] + backslash := p[2] + if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' { + return true + } + return false + } + return strings.HasPrefix(path, "/") || isWindowsPath(path) +} + +// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with +// net.Dial. +func NetworkAddress(host string, port uint16) (network, address string) { + if isAbsolutePath(host) { + network = "unix" + address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) + } else { + network = "tcp" + address = net.JoinHostPort(host, strconv.Itoa(int(port))) + } + return network, address +} + +// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It +// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely +// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). +// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be +// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. +// +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca +// +// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done +// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be +// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should +// not be modified individually. They should all be modified or all left unchanged. +// +// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated +// values that will be tried in order. This can be used as part of a high availability system. See +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb +// +// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed +// via database URL or DSN: +// +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSERVICE +// PGSERVICEFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT +// PGSSLPASSWORD +// PGAPPNAME +// PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS +// +// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. +// +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are +// usually but not always the environment variable name downcased and without the "PG" prefix. +// +// Important Security Notes: +// +// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if +// not set. +// +// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of +// security each sslmode provides. +// +// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of +// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of +// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback +// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually +// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting +// TLSConfig. +// +// Other known differences with libpq: +// +// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn +// does not. +// +// In addition, ParseConfig accepts the following options: +// +// servicefile +// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a +// part of the connection string. +func ParseConfig(connString string) (*Config, error) { + var parseConfigOptions ParseConfigOptions + return ParseConfigWithOptions(connString, parseConfigOptions) +} + +// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard +// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to +// get the SSL password. +func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) { + defaultSettings := defaultSettings() + envSettings := parseEnvSettings() + + connStringSettings := make(map[string]string) + if connString != "" { + var err error + // connString may be a database URL or a DSN + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { + connStringSettings, err = parseURLSettings(connString) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} + } + } else { + connStringSettings, err = parseDSNSettings(connString) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} + } + } + } + + settings := mergeSettings(defaultSettings, envSettings, connStringSettings) + if service, present := settings["service"]; present { + serviceSettings, err := parseServiceSettings(settings["servicefile"], service) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} + } + + settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) + } + + config := &Config{ + createdByParseConfig: true, + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { + return pgproto3.NewFrontend(r, w) + }, + } + + if connectTimeoutSetting, present := settings["connect_timeout"]; present { + connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} + } + config.ConnectTimeout = connectTimeout + config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) + } else { + defaultDialer := makeDefaultDialer() + config.DialFunc = defaultDialer.DialContext + } + + config.LookupFunc = makeDefaultResolver().LookupHost + + notRuntimeParams := map[string]struct{}{ + "host": {}, + "port": {}, + "database": {}, + "user": {}, + "password": {}, + "passfile": {}, + "connect_timeout": {}, + "sslmode": {}, + "sslkey": {}, + "sslcert": {}, + "sslrootcert": {}, + "sslpassword": {}, + "sslsni": {}, + "krbspn": {}, + "krbsrvname": {}, + "target_session_attrs": {}, + "service": {}, + "servicefile": {}, + } + + // Adding kerberos configuration + if _, present := settings["krbsrvname"]; present { + config.KerberosSrvName = settings["krbsrvname"] + } + if _, present := settings["krbspn"]; present { + config.KerberosSpn = settings["krbspn"] + } + + for k, v := range settings { + if _, present := notRuntimeParams[k]; present { + continue + } + config.RuntimeParams[k] = v + } + + fallbacks := []*FallbackConfig{} + + hosts := strings.Split(settings["host"], ",") + ports := strings.Split(settings["port"], ",") + + for i, host := range hosts { + var portStr string + if i < len(ports) { + portStr = ports[i] + } else { + portStr = ports[0] + } + + port, err := parsePort(portStr) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := NetworkAddress(host, port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configTLS(settings, host, options) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} + } + } + + for _, tlsConfig := range tlsConfigs { + fallbacks = append(fallbacks, &FallbackConfig{ + Host: host, + Port: port, + TLSConfig: tlsConfig, + }) + } + } + + config.Host = fallbacks[0].Host + config.Port = fallbacks[0].Port + config.TLSConfig = fallbacks[0].TLSConfig + config.Fallbacks = fallbacks[1:] + + passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) + if err == nil { + if config.Password == "" { + host := config.Host + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + host = "localhost" + } + + config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) + } + } + + switch tsa := settings["target_session_attrs"]; tsa { + case "read-write": + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite + case "read-only": + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly + case "primary": + config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary + case "standby": + config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby + case "prefer-standby": + config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby + case "any": + // do nothing + default: + return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} + } + + return config, nil +} + +func mergeSettings(settingSets ...map[string]string) map[string]string { + settings := make(map[string]string) + + for _, s2 := range settingSets { + for k, v := range s2 { + settings[k] = v + } + } + + return settings +} + +func parseEnvSettings() map[string]string { + settings := make(map[string]string) + + nameMap := map[string]string{ + "PGHOST": "host", + "PGPORT": "port", + "PGDATABASE": "database", + "PGUSER": "user", + "PGPASSWORD": "password", + "PGPASSFILE": "passfile", + "PGAPPNAME": "application_name", + "PGCONNECT_TIMEOUT": "connect_timeout", + "PGSSLMODE": "sslmode", + "PGSSLKEY": "sslkey", + "PGSSLCERT": "sslcert", + "PGSSLSNI": "sslsni", + "PGSSLROOTCERT": "sslrootcert", + "PGSSLPASSWORD": "sslpassword", + "PGTARGETSESSIONATTRS": "target_session_attrs", + "PGSERVICE": "service", + "PGSERVICEFILE": "servicefile", + } + + for envname, realname := range nameMap { + value := os.Getenv(envname) + if value != "" { + settings[realname] = value + } + } + + return settings +} + +func parseURLSettings(connString string) (map[string]string, error) { + settings := make(map[string]string) + + url, err := url.Parse(connString) + if err != nil { + return nil, err + } + + if url.User != nil { + settings["user"] = url.User.Username() + if password, present := url.User.Password(); present { + settings["password"] = password + } + } + + // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. + var hosts []string + var ports []string + for _, host := range strings.Split(url.Host, ",") { + if host == "" { + continue + } + if isIPOnly(host) { + hosts = append(hosts, strings.Trim(host, "[]")) + continue + } + h, p, err := net.SplitHostPort(host) + if err != nil { + return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) + } + if h != "" { + hosts = append(hosts, h) + } + if p != "" { + ports = append(ports, p) + } + } + if len(hosts) > 0 { + settings["host"] = strings.Join(hosts, ",") + } + if len(ports) > 0 { + settings["port"] = strings.Join(ports, ",") + } + + database := strings.TrimLeft(url.Path, "/") + if database != "" { + settings["database"] = database + } + + nameMap := map[string]string{ + "dbname": "database", + } + + for k, v := range url.Query() { + if k2, present := nameMap[k]; present { + k = k2 + } + + settings[k] = v[0] + } + + return settings, nil +} + +func isIPOnly(host string) bool { + return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":") +} + +var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} + +func parseDSNSettings(s string) (map[string]string, error) { + settings := make(map[string]string) + + nameMap := map[string]string{ + "dbname": "database", + } + + for len(s) > 0 { + var key, val string + eqIdx := strings.IndexRune(s, '=') + if eqIdx < 0 { + return nil, errors.New("invalid dsn") + } + + key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") + s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") + if len(s) == 0 { + } else if s[0] != '\'' { + end := 0 + for ; end < len(s); end++ { + if asciiSpace[s[end]] == 1 { + break + } + if s[end] == '\\' { + end++ + if end == len(s) { + return nil, errors.New("invalid backslash") + } + } + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } else { // quoted string + s = s[1:] + end := 0 + for ; end < len(s); end++ { + if s[end] == '\'' { + break + } + if s[end] == '\\' { + end++ + } + } + if end == len(s) { + return nil, errors.New("unterminated quoted string in connection info string") + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } + + if k, ok := nameMap[key]; ok { + key = k + } + + if key == "" { + return nil, errors.New("invalid dsn") + } + + settings[key] = val + } + + return settings, nil +} + +func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { + servicefile, err := pgservicefile.ReadServicefile(servicefilePath) + if err != nil { + return nil, fmt.Errorf("failed to read service file: %v", servicefilePath) + } + + service, err := servicefile.GetService(serviceName) + if err != nil { + return nil, fmt.Errorf("unable to find service: %v", serviceName) + } + + nameMap := map[string]string{ + "dbname": "database", + } + + settings := make(map[string]string, len(service.Settings)) + for k, v := range service.Settings { + if k2, present := nameMap[k]; present { + k = k2 + } + settings[k] = v + } + + return settings, nil +} + +// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is +// necessary to allow returning multiple TLS configs as sslmode "allow" and +// "prefer" allow fallback. +func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) { + host := thisHost + sslmode := settings["sslmode"] + sslrootcert := settings["sslrootcert"] + sslcert := settings["sslcert"] + sslkey := settings["sslkey"] + sslpassword := settings["sslpassword"] + sslsni := settings["sslsni"] + + // Match libpq default behavior + if sslmode == "" { + sslmode = "prefer" + } + if sslsni == "" { + sslsni = "1" + } + + tlsConfig := &tls.Config{} + + switch sslmode { + case "disable": + return []*tls.Config{nil}, nil + case "allow", "prefer": + tlsConfig.InsecureSkipVerify = true + case "require": + // According to PostgreSQL documentation, if a root CA file exists, + // the behavior of sslmode=require should be the same as that of verify-ca + // + // See https://www.postgresql.org/docs/12/libpq-ssl.html + if sslrootcert != "" { + goto nextCase + } + tlsConfig.InsecureSkipVerify = true + break + nextCase: + fallthrough + case "verify-ca": + // Don't perform the default certificate verification because it + // will verify the hostname. Instead, verify the server's + // certificate chain ourselves in VerifyPeerCertificate and + // ignore the server name. This emulates libpq's verify-ca + // behavior. + // + // See https://github.com/golang/go/issues/21971#issuecomment-332693931 + // and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate + // for more info. + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return errors.New("failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + // Leave DNSName empty to skip hostname verification. + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + Intermediates: x509.NewCertPool(), + } + // Skip the first cert because it's the leaf. All others + // are intermediates. + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := certs[0].Verify(opts) + return err + } + case "verify-full": + tlsConfig.ServerName = host + default: + return nil, errors.New("sslmode is invalid") + } + + if sslrootcert != "" { + caCertPool := x509.NewCertPool() + + caPath := sslrootcert + caCert, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("unable to read CA file: %w", err) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.New("unable to add CA to cert pool") + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return nil, errors.New(`both "sslcert" and "sslkey" are required`) + } + + if sslcert != "" && sslkey != "" { + buf, err := ioutil.ReadFile(sslkey) + if err != nil { + return nil, fmt.Errorf("unable to read sslkey: %w", err) + } + block, _ := pem.Decode(buf) + var pemKey []byte + var decryptedKey []byte + var decryptedError error + // If PEM is encrypted, attempt to decrypt using pass phrase + if x509.IsEncryptedPEMBlock(block) { + // Attempt decryption with pass phrase + // NOTE: only supports RSA (PKCS#1) + if sslpassword != "" { + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + } + //if sslpassword not provided or has decryption error when use it + //try to find sslpassword with callback function + if sslpassword == "" || decryptedError != nil { + if parseConfigOptions.GetSSLPassword != nil { + sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) + } + if sslpassword == "" { + return nil, fmt.Errorf("unable to find sslpassword") + } + } + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + // Should we also provide warning for PKCS#1 needed? + if decryptedError != nil { + return nil, fmt.Errorf("unable to decrypt key: %w", err) + } + + pemBytes := pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: decryptedKey, + } + pemKey = pem.EncodeToMemory(&pemBytes) + } else { + pemKey = pem.EncodeToMemory(block) + } + certfile, err := ioutil.ReadFile(sslcert) + if err != nil { + return nil, fmt.Errorf("unable to read cert: %w", err) + } + cert, err := tls.X509KeyPair(certfile, pemKey) + if err != nil { + return nil, fmt.Errorf("unable to load cert: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + // Set Server Name Indication (SNI), if enabled by connection parameters. + // Per RFC 6066, do not set it if the host is a literal IP address (IPv4 + // or IPv6). + if sslsni == "1" && net.ParseIP(host) == nil { + tlsConfig.ServerName = host + } + + switch sslmode { + case "allow": + return []*tls.Config{nil, tlsConfig}, nil + case "prefer": + return []*tls.Config{tlsConfig, nil}, nil + case "require", "verify-ca", "verify-full": + return []*tls.Config{tlsConfig}, nil + default: + panic("BUG: bad sslmode should already have been caught") + } +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +func makeDefaultDialer() *net.Dialer { + return &net.Dialer{KeepAlive: 5 * time.Minute} +} + +func makeDefaultResolver() *net.Resolver { + return net.DefaultResolver +} + +func parseConnectTimeoutSetting(s string) (time.Duration, error) { + timeout, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, err + } + if timeout < 0 { + return 0, errors.New("negative timeout") + } + return time.Duration(timeout) * time.Second, nil +} + +func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { + d := makeDefaultDialer() + d.Timeout = timeout + return d.DialContext +} + +// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=read-write. +func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) == "on" { + return errors.New("read only connection") + } + + return nil +} + +// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=read-only. +func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) != "on" { + return errors.New("connection is not read only") + } + + return nil +} + +// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=standby. +func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) != "t" { + return errors.New("server is not in hot standby mode") + } + + return nil +} + +// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=primary. +func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) == "t" { + return errors.New("server is in standby mode") + } + + return nil +} + +// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=prefer-standby. +func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) != "t" { + return &NotPreferredError{err: errors.New("server is not in hot standby mode")} + } + + return nil +} diff --git a/pgconn/config_test.go b/pgconn/config_test.go new file mode 100644 index 00000000..94ca2729 --- /dev/null +++ b/pgconn/config_test.go @@ -0,0 +1,1134 @@ +package pgconn_test + +import ( + "context" + "crypto/tls" + "fmt" + "io/ioutil" + "os" + "os/user" + "runtime" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseConfig(t *testing.T) { + t.Parallel() + + var osUserName string + osUser, err := user.Current() + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { + osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] + } else { + osUserName = osUser.Username + } + } + + config, err := pgconn.ParseConfig("") + require.NoError(t, err) + defaultHost := config.Host + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + // Test all sslmodes + { + name: "sslmode not set (prefer)", + connString: "postgres://jack:secret@localhost:5432/mydb", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode disable", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode allow", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }, + }, + }, + }, + }, + { + name: "sslmode prefer", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + config: &pgconn.Config{ + + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode require", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-ca", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-full", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ServerName: "localhost"}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url everything", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + ConnectTimeout: 5 * time.Second, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + name: "database url missing password", + connString: "postgres://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing user and password", + connString: "postgres://localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: osUserName, + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing port", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url unix domain socket host", + connString: "postgres:///foo?host=/tmp", + config: &pgconn.Config{ + User: osUserName, + Host: "/tmp", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url unix domain socket host on windows", + connString: "postgres:///foo?host=C:\\tmp", + config: &pgconn.Config{ + User: osUserName, + Host: "C:\\tmp", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url dbname", + connString: "postgres://localhost/?dbname=foo&sslmode=disable", + config: &pgconn.Config{ + User: osUserName, + Host: "localhost", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url postgresql protocol", + connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv4 with port", + connString: "postgresql://jack@127.0.0.1:5433/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "127.0.0.1", + Port: 5433, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv6 with port", + connString: "postgresql://jack@[2001:db8::1]:5433/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "2001:db8::1", + Port: 5433, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv6 no port", + connString: "postgresql://jack@[2001:db8::1]/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "2001:db8::1", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN everything", + connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + ConnectTimeout: 5 * time.Second, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + name: "DSN with escaped single quote", + connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack's", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with escaped backslash", + connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "sooper\\secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with single quoted values", + connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with single quoted value with escaped single quote", + connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack's", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with empty single quoted value", + connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with space between key and value", + connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "URL multiple hosts", + connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "URL multiple hosts and ports", + connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + // https://github.com/jackc/pgconn/issues/72 + { + name: "URL without host but with port still uses default host", + connString: "postgres://jack:secret@:1/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: defaultHost, + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN multiple hosts one port", + connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "DSN multiple hosts multiple ports", + connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "multiple hosts and fallback tsl", + connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "foo", + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "foo", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "bar", + }}, + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "baz", + }}, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "target_session_attrs read-write", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite, + }, + }, + { + name: "target_session_attrs read-only", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-only", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadOnly, + }, + }, + { + name: "target_session_attrs primary", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=primary", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPrimary, + }, + }, + { + name: "target_session_attrs standby", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=standby", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsStandby, + }, + }, + { + name: "target_session_attrs prefer-standby", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPreferStandby, + }, + }, + { + name: "target_session_attrs any", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=any", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "target_session_attrs not set (any)", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is set by default", + connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "sni.test", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "sni.test", + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set for IPv4", + connString: "postgres://jack:secret@1.1.1.1:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "1.1.1.1", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set for IPv6", + connString: "postgres://jack:secret@[::1]:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "::1", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set when disabled (URL-style)", + connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require&sslsni=0", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "sni.test", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set when disabled (key/value style)", + connString: "user=jack password=secret host=sni.test dbname=mydb sslmode=require sslsni=0", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "sni.test", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +// https://github.com/jackc/pgconn/issues/47 +func TestParseConfigDSNWithTrailingEmptyEqualDoesNotPanic(t *testing.T) { + _, err := pgconn.ParseConfig("host= user= password= port= database=") + require.NoError(t, err) +} + +func TestParseConfigDSNLeadingEqual(t *testing.T) { + _, err := pgconn.ParseConfig("= user=jack") + require.Error(t, err) +} + +// https://github.com/jackc/pgconn/issues/49 +func TestParseConfigDSNTrailingBackslash(t *testing.T) { + _, err := pgconn.ParseConfig(`x=x\`) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid backslash") +} + +func TestConfigCopyReturnsEqualConfig(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") +} + +func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5&sslmode=prefer" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") + + copied.Port = uint16(5433) + copied.RuntimeParams["foo"] = "bar" + copied.Fallbacks[0].Port = uint16(5433) + + assert.Equal(t, uint16(5432), original.Port) + assert.Equal(t, "", original.RuntimeParams["foo"]) + assert.Equal(t, uint16(5432), original.Fallbacks[0].Port) +} + +func TestConfigCopyCanBeUsedToConnect(t *testing.T) { + connString := os.Getenv("PGX_TEST_CONN_STRING") + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assert.NotPanics(t, func() { + _, err = pgconn.ConnectConfig(context.Background(), copied) + }) + assert.NoError(t, err) +} + +func TestNetworkAddress(t *testing.T) { + tests := []struct { + name string + host string + wantNet string + }{ + { + name: "Default Unix socket address", + host: "/var/run/postgresql", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (standard drive name)", + host: "C:\\tmp", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (first drive name)", + host: "A:\\tmp", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (last drive name)", + host: "Z:\\tmp", + wantNet: "unix", + }, + { + name: "Assume TCP for unknown formats", + host: "a/tmp", + wantNet: "tcp", + }, + { + name: "loopback interface", + host: "localhost", + wantNet: "tcp", + }, + { + name: "IP address", + host: "127.0.0.1", + wantNet: "tcp", + }, + } + for i, tt := range tests { + gotNet, _ := pgconn.NetworkAddress(tt.host, 5432) + + assert.Equalf(t, tt.wantNet, gotNet, "Test %d (%s)", i, tt.name) + } +} + +func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) + assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) + assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) + assert.Equalf(t, expected.User, actual.User, "%s - User", testName) + assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) + assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) + + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { + if expected.TLSConfig != nil { + assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) + } + } + + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { + for i := range expected.Fallbacks { + assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) + assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) + + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { + if expected.Fallbacks[i].TLSConfig != nil { + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) + } + } + } + } +} + +func TestParseConfigEnvLibpq(t *testing.T) { + var osUserName string + osUser, err := user.Current() + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { + osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] + } else { + osUserName = osUser.Username + } + } + + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT", "PGSSLSNI"} + + savedEnv := make(map[string]string) + for _, n := range pgEnvvars { + savedEnv[n] = os.Getenv(n) + } + defer func() { + for k, v := range savedEnv { + err := os.Setenv(k, v) + if err != nil { + t.Fatalf("Unable to restore environment: %v", err) + } + } + }() + + tests := []struct { + name string + envvars map[string]string + config *pgconn.Config + }{ + { + // not testing no environment at all as that would use default host and that can vary. + name: "PGHOST only", + envvars: map[string]string{"PGHOST": "123.123.123.123"}, + config: &pgconn.Config{ + User: osUserName, + Host: "123.123.123.123", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "123.123.123.123", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "All non-TLS environment", + envvars: map[string]string{ + "PGHOST": "123.123.123.123", + "PGPORT": "7777", + "PGDATABASE": "foo", + "PGUSER": "bar", + "PGPASSWORD": "baz", + "PGCONNECT_TIMEOUT": "10", + "PGSSLMODE": "disable", + "PGAPPNAME": "pgxtest", + }, + config: &pgconn.Config{ + Host: "123.123.123.123", + Port: 7777, + Database: "foo", + User: "bar", + Password: "baz", + ConnectTimeout: 10 * time.Second, + TLSConfig: nil, + RuntimeParams: map[string]string{"application_name": "pgxtest"}, + }, + }, + { + name: "SNI can be disabled via environment variable", + envvars: map[string]string{ + "PGHOST": "test.foo", + "PGSSLMODE": "require", + "PGSSLSNI": "0", + }, + config: &pgconn.Config{ + User: osUserName, + Host: "test.foo", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + } + + for i, tt := range tests { + for _, n := range pgEnvvars { + err := os.Unsetenv(n) + require.NoError(t, err) + } + + for k, v := range tt.envvars { + err := os.Setenv(k, v) + require.NoError(t, err) + } + + config, err := pgconn.ParseConfig("") + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func TestParseConfigReadsPgPassfile(t *testing.T) { + t.Parallel() + + tf, err := ioutil.TempFile("", "") + require.NoError(t, err) + + defer tf.Close() + defer os.Remove(tf.Name()) + + _, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk")) + require.NoError(t, err) + + connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name()) + expected := &pgconn.Config{ + User: "curly", + Password: "nyuknyuknyuk", + Host: "test1", + Port: 5432, + Database: "curlydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + } + + actual, err := pgconn.ParseConfig(connString) + assert.NoError(t, err) + + assertConfigsEqual(t, expected, actual, "passfile") +} + +func TestParseConfigReadsPgServiceFile(t *testing.T) { + t.Parallel() + + tf, err := ioutil.TempFile("", "") + require.NoError(t, err) + + defer tf.Close() + defer os.Remove(tf.Name()) + + _, err = tf.Write([]byte(` +[abc] +host=abc.example.com +port=9999 +dbname=abcdb +user=abcuser + +[def] +host = def.example.com +dbname = defdb +user = defuser +application_name = spaced string +`)) + require.NoError(t, err) + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + { + name: "abc", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"), + config: &pgconn.Config{ + Host: "abc.example.com", + Database: "abcdb", + User: "abcuser", + Port: 9999, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "abc.example.com", + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "abc.example.com", + Port: 9999, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "def", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"), + config: &pgconn.Config{ + Host: "def.example.com", + Port: 5432, + Database: "defdb", + User: "defuser", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "def.example.com", + }, + RuntimeParams: map[string]string{"application_name": "spaced string"}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "def.example.com", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "conn string has precedence", + connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"), + config: &pgconn.Config{ + Host: "other.example.com", + Database: "abcdb", + User: "abcuser", + Port: 7777, + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.NoErrorf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} diff --git a/pgconn/defaults.go b/pgconn/defaults.go new file mode 100644 index 00000000..1dd514ff --- /dev/null +++ b/pgconn/defaults.go @@ -0,0 +1,63 @@ +//go:build !windows +// +build !windows + +package pgconn + +import ( + "os" + "os/user" + "path/filepath" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + if err == nil { + settings["user"] = user.Username + settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } + } + + settings["target_session_attrs"] = "any" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + candidatePaths := []string{ + "/var/run/postgresql", // Debian + "/private/tmp", // OSX - homebrew + "/tmp", // standard PostgreSQL + } + + for _, path := range candidatePaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "localhost" +} diff --git a/pgconn/defaults_windows.go b/pgconn/defaults_windows.go new file mode 100644 index 00000000..33b4a1ff --- /dev/null +++ b/pgconn/defaults_windows.go @@ -0,0 +1,57 @@ +package pgconn + +import ( + "os" + "os/user" + "path/filepath" + "strings" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + appData := os.Getenv("APPDATA") + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + username := user.Username + if strings.Contains(username, "\\") { + username = username[strings.LastIndex(username, "\\")+1:] + } + + settings["user"] = username + settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(appData, "postgresql", "postgresql.crt") + sslkey := filepath.Join(appData, "postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(appData, "postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } + } + + settings["target_session_attrs"] = "any" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + return "localhost" +} diff --git a/pgconn/doc.go b/pgconn/doc.go new file mode 100644 index 00000000..e3242cf4 --- /dev/null +++ b/pgconn/doc.go @@ -0,0 +1,34 @@ +// Package pgconn is a low-level PostgreSQL database driver. +/* +pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at +nearly the same level is the C library libpq. + +Establishing a Connection + +Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for +libpq style environment variables. + +Executing a Query + +ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method +reads all rows into memory. + +Executing Multiple Queries in a Single Round Trip + +Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query +result. The ReadAll method reads all query results into memory. + +Pipeline Mode + +Pipeline mode allows sending queries without having read the results of previously sent queries. It allows +control of exactly how many and when network round trips occur. + +Context Support + +All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the +method immediately returns. In most circumstances, this will close the underlying connection. + +The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the +client to abort. +*/ +package pgconn diff --git a/pgconn/errors.go b/pgconn/errors.go new file mode 100644 index 00000000..3c54bbec --- /dev/null +++ b/pgconn/errors.go @@ -0,0 +1,226 @@ +package pgconn + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "regexp" + "strings" +) + +// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. +func SafeToRetry(err error) bool { + if e, ok := err.(interface{ SafeToRetry() bool }); ok { + return e.SafeToRetry() + } + return false +} + +// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a +// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. +func Timeout(err error) bool { + var timeoutErr *errTimeout + return errors.As(err, &timeoutErr) +} + +// PgError represents an error reported by the PostgreSQL server. See +// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for +// detailed field description. +type PgError struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string +} + +func (pe *PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" +} + +// SQLState returns the SQLState of the error. +func (pe *PgError) SQLState() string { + return pe.Code +} + +type connectError struct { + config *Config + msg string + err error +} + +func (e *connectError) Error() string { + sb := &strings.Builder{} + fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) + if e.err != nil { + fmt.Fprintf(sb, " (%s)", e.err.Error()) + } + return sb.String() +} + +func (e *connectError) Unwrap() error { + return e.err +} + +type connLockError struct { + status string +} + +func (e *connLockError) SafeToRetry() bool { + return true // a lock failure by definition happens before the connection is used. +} + +func (e *connLockError) Error() string { + return e.status +} + +type parseConfigError struct { + connString string + msg string + err error +} + +func (e *parseConfigError) Error() string { + connString := redactPW(e.connString) + if e.err == nil { + return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg) + } + return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error()) +} + +func (e *parseConfigError) Unwrap() error { + return e.err +} + +func normalizeTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() { + if ctx.Err() == context.Canceled { + // Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error. + return context.Canceled + } else if ctx.Err() == context.DeadlineExceeded { + return &errTimeout{err: ctx.Err()} + } else { + return &errTimeout{err: err} + } + } + return err +} + +type pgconnError struct { + msg string + err error + safeToRetry bool +} + +func (e *pgconnError) Error() string { + if e.msg == "" { + return e.err.Error() + } + if e.err == nil { + return e.msg + } + return fmt.Sprintf("%s: %s", e.msg, e.err.Error()) +} + +func (e *pgconnError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *pgconnError) Unwrap() error { + return e.err +} + +// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is +// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true. +type errTimeout struct { + err error +} + +func (e *errTimeout) Error() string { + return fmt.Sprintf("timeout: %s", e.err.Error()) +} + +func (e *errTimeout) SafeToRetry() bool { + return SafeToRetry(e.err) +} + +func (e *errTimeout) Unwrap() error { + return e.err +} + +type contextAlreadyDoneError struct { + err error +} + +func (e *contextAlreadyDoneError) Error() string { + return fmt.Sprintf("context already done: %s", e.err.Error()) +} + +func (e *contextAlreadyDoneError) SafeToRetry() bool { + return true +} + +func (e *contextAlreadyDoneError) Unwrap() error { + return e.err +} + +// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`. +func newContextAlreadyDoneError(ctx context.Context) (err error) { + return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}} +} + +func redactPW(connString string) string { + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { + if u, err := url.Parse(connString); err == nil { + return redactURL(u) + } + } + quotedDSN := regexp.MustCompile(`password='[^']*'`) + connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + plainDSN := regexp.MustCompile(`password=[^ ]*`) + connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + brokenURL := regexp.MustCompile(`:[^:@]+?@`) + connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") + return connString +} + +func redactURL(u *url.URL) string { + if u == nil { + return "" + } + if _, pwSet := u.User.Password(); pwSet { + u.User = url.UserPassword(u.User.Username(), "xxxxx") + } + return u.String() +} + +type NotPreferredError struct { + err error + safeToRetry bool +} + +func (e *NotPreferredError) Error() string { + return fmt.Sprintf("standby server not found: %s", e.err.Error()) +} + +func (e *NotPreferredError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *NotPreferredError) Unwrap() error { + return e.err +} diff --git a/pgconn/errors_test.go b/pgconn/errors_test.go new file mode 100644 index 00000000..9d559346 --- /dev/null +++ b/pgconn/errors_test.go @@ -0,0 +1,54 @@ +package pgconn_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" +) + +func TestConfigError(t *testing.T) { + tests := []struct { + name string + err error + expectedMsg string + }{ + { + name: "url with password", + err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil), + expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg", + }, + { + name: "dsn with password unquoted", + err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil), + expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", + }, + { + name: "dsn with password quoted", + err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil), + expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", + }, + { + name: "weird url", + err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil), + expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg", + }, + { + name: "weird url with slash in password", + err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil), + expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg", + }, + { + name: "url without password", + err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil), + expectedMsg: "cannot parse `postgresql://other@host/db`: msg", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.EqualError(t, tt.err, tt.expectedMsg) + }) + } +} diff --git a/pgconn/export_test.go b/pgconn/export_test.go new file mode 100644 index 00000000..2a0bad8b --- /dev/null +++ b/pgconn/export_test.go @@ -0,0 +1,11 @@ +// File export_test exports some methods for better testing. + +package pgconn + +func NewParseConfigError(conn, msg string, err error) error { + return &parseConfigError{ + connString: conn, + msg: msg, + err: err, + } +} diff --git a/pgconn/helper_test.go b/pgconn/helper_test.go new file mode 100644 index 00000000..0696f4ce --- /dev/null +++ b/pgconn/helper_test.go @@ -0,0 +1,36 @@ +package pgconn_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func closeConn(t testing.TB, conn *pgconn.PgConn) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + require.NoError(t, conn.Close(ctx)) + select { + case <-conn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +// Do a simple query to ensure the connection is still usable +func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read() + cancel() + + require.Nil(t, result.Err) + assert.Equal(t, 3, len(result.Rows)) + assert.Equal(t, "1", string(result.Rows[0][0])) + assert.Equal(t, "2", string(result.Rows[1][0])) + assert.Equal(t, "3", string(result.Rows[2][0])) +} diff --git a/pgconn/internal/ctxwatch/context_watcher.go b/pgconn/internal/ctxwatch/context_watcher.go new file mode 100644 index 00000000..b39cb3ee --- /dev/null +++ b/pgconn/internal/ctxwatch/context_watcher.go @@ -0,0 +1,73 @@ +package ctxwatch + +import ( + "context" + "sync" +) + +// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a +// time. +type ContextWatcher struct { + onCancel func() + onUnwatchAfterCancel func() + unwatchChan chan struct{} + + lock sync.Mutex + watchInProgress bool + onCancelWasCalled bool +} + +// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. +// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and +// onCancel called. +func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { + cw := &ContextWatcher{ + onCancel: onCancel, + onUnwatchAfterCancel: onUnwatchAfterCancel, + unwatchChan: make(chan struct{}), + } + + return cw +} + +// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. +func (cw *ContextWatcher) Watch(ctx context.Context) { + cw.lock.Lock() + defer cw.lock.Unlock() + + if cw.watchInProgress { + panic("Watch already in progress") + } + + cw.onCancelWasCalled = false + + if ctx.Done() != nil { + cw.watchInProgress = true + go func() { + select { + case <-ctx.Done(): + cw.onCancel() + cw.onCancelWasCalled = true + <-cw.unwatchChan + case <-cw.unwatchChan: + } + }() + } else { + cw.watchInProgress = false + } +} + +// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was +// called then onUnwatchAfterCancel will also be called. +func (cw *ContextWatcher) Unwatch() { + cw.lock.Lock() + defer cw.lock.Unlock() + + if cw.watchInProgress { + cw.unwatchChan <- struct{}{} + if cw.onCancelWasCalled { + cw.onUnwatchAfterCancel() + } + cw.watchInProgress = false + } +} diff --git a/pgconn/internal/ctxwatch/context_watcher_test.go b/pgconn/internal/ctxwatch/context_watcher_test.go new file mode 100644 index 00000000..39652995 --- /dev/null +++ b/pgconn/internal/ctxwatch/context_watcher_test.go @@ -0,0 +1,163 @@ +package ctxwatch_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" + "github.com/stretchr/testify/require" +) + +func TestContextWatcherContextCancelled(t *testing.T) { + canceledChan := make(chan struct{}) + cleanupCalled := false + cw := ctxwatch.NewContextWatcher(func() { + canceledChan <- struct{}{} + }, func() { + cleanupCalled = true + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + + select { + case <-canceledChan: + case <-time.NewTimer(time.Second).C: + t.Fatal("Timed out waiting for cancel func to be called") + } + + cw.Unwatch() + + require.True(t, cleanupCalled, "Cleanup func was not called") +} + +func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() { + t.Error("cancel func should not have been called") + }, func() { + t.Error("cleanup func should not have been called") + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cw.Unwatch() + cancel() +} + +func TestContextWatcherMultipleWatchPanics(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") +} + +func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + cw.Unwatch() // unwatch when not / never watching + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + cw.Unwatch() + cw.Unwatch() // double unwatch +} + +func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + cw.Watch(ctx) + + go cw.Unwatch() + go cw.Unwatch() + + <-ctx.Done() +} + +func TestContextWatcherStress(t *testing.T) { + var cancelFuncCalls int64 + var cleanupFuncCalls int64 + + cw := ctxwatch.NewContextWatcher(func() { + atomic.AddInt64(&cancelFuncCalls, 1) + }, func() { + atomic.AddInt64(&cleanupFuncCalls, 1) + }) + + cycleCount := 100000 + + for i := 0; i < cycleCount; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + if i%2 == 0 { + cancel() + } + + // Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix. + if i%3 == 0 { + time.Sleep(time.Nanosecond) + } + + cw.Unwatch() + if i%2 == 1 { + cancel() + } + } + + actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) + actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls) + + if actualCancelFuncCalls == 0 { + t.Fatal("actualCancelFuncCalls == 0") + } + + maxCancelFuncCalls := int64(cycleCount) / 2 + if actualCancelFuncCalls > maxCancelFuncCalls { + t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls) + } + + if actualCancelFuncCalls != actualCleanupFuncCalls { + t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls) + } +} + +func BenchmarkContextWatcherUncancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + for i := 0; i < b.N; i++ { + cw.Watch(context.Background()) + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancelled(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + cw.Watch(ctx) + cw.Unwatch() + } +} diff --git a/pgconn/krb5.go b/pgconn/krb5.go new file mode 100644 index 00000000..969675fd --- /dev/null +++ b/pgconn/krb5.go @@ -0,0 +1,100 @@ +package pgconn + +import ( + "errors" + "fmt" + + "github.com/jackc/pgx/v5/pgproto3" +) + +// NewGSSFunc creates a GSS authentication provider, for use with +// RegisterGSSProvider. +type NewGSSFunc func() (GSS, error) + +var newGSS NewGSSFunc + +// RegisterGSSProvider registers a GSS authentication provider. For example, if +// you need to use Kerberos to authenticate with your server, add this to your +// main package: +// +// import "github.com/otan/gopgkrb5" +// +// func init() { +// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() }) +// } +func RegisterGSSProvider(newGSSArg NewGSSFunc) { + newGSS = newGSSArg +} + +// GSS provides GSSAPI authentication (e.g., Kerberos). +type GSS interface { + GetInitToken(host string, service string) ([]byte, error) + GetInitTokenFromSPN(spn string) ([]byte, error) + Continue(inToken []byte) (done bool, outToken []byte, err error) +} + +func (c *PgConn) gssAuth() error { + if newGSS == nil { + return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5") + } + cli, err := newGSS() + if err != nil { + return err + } + + var nextData []byte + if c.config.KerberosSpn != "" { + // Use the supplied SPN if provided. + nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn) + } else { + // Allow the kerberos service name to be overridden + service := "postgres" + if c.config.KerberosSrvName != "" { + service = c.config.KerberosSrvName + } + nextData, err = cli.GetInitToken(c.config.Host, service) + } + if err != nil { + return err + } + + for { + gssResponse := &pgproto3.GSSResponse{ + Data: nextData, + } + c.frontend.Send(gssResponse) + err = c.frontend.Flush() + if err != nil { + return err + } + resp, err := c.rxGSSContinue() + if err != nil { + return err + } + var done bool + done, nextData, err = cli.Continue(resp.Data) + if err != nil { + return err + } + if done { + break + } + } + return nil +} + +func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + + switch m := msg.(type) { + case *pgproto3.AuthenticationGSSContinue: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) + } + + return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg) +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go new file mode 100644 index 00000000..59fa35c6 --- /dev/null +++ b/pgconn/pgconn.go @@ -0,0 +1,1962 @@ +package pgconn + +import ( + "context" + "crypto/md5" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "math" + "net" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/v5/internal/iobufpool" + "github.com/jackc/pgx/v5/internal/nbconn" + "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" + "github.com/jackc/pgx/v5/pgproto3" +) + +const ( + connStatusUninitialized = iota + connStatusConnecting + connStatusClosed + connStatusIdle + connStatusBusy +) + +// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from +// LISTEN/NOTIFY notification. +type Notice PgError + +// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system +type Notification struct { + PID uint32 // backend pid that sent the notification + Channel string // channel from which notification was received + Payload string +} + +// DialFunc is a function that can be used to connect to a PostgreSQL server. +type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + +// LookupFunc is a function that can be used to lookup IPs addrs from host. Optionally an ip:port combination can be +// returned in order to override the connection string's port. +type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) + +// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. +type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend + +// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at +// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin +// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY +// notification. +type NoticeHandler func(*PgConn, *Notice) + +// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications +// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is +// aware of the origin of the notice, but it must not invoke any query method. Be aware that this is distinct from a +// notice event. +type NotificationHandler func(*PgConn, *Notification) + +// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. +type PgConn struct { + conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection + pid uint32 // backend pid + secretKey uint32 // key to use to send a cancel query message to the server + parameterStatuses map[string]string // parameters that have been reported by the server + txStatus byte + frontend *pgproto3.Frontend + + config *Config + + status byte // One of connStatus* constants + + peekedMsg pgproto3.BackendMessage + + // Reusable / preallocated resources + resultReader ResultReader + multiResultReader MultiResultReader + pipeline Pipeline + contextWatcher *ctxwatch.ContextWatcher + fieldDescriptions [16]FieldDescription + + cleanupDone chan struct{} +} + +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) +// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt. +func Connect(ctx context.Context, connString string) (*PgConn, error) { + config, err := ParseConfig(connString) + if err != nil { + return nil, err + } + + return ConnectConfig(ctx, config) +} + +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) +// and ParseConfigOptions to provide additional configuration. See documentation for ParseConfig for details. ctx can be +// used to cancel a connect attempt. +func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { + config, err := ParseConfigWithOptions(connString, parseConfigOptions) + if err != nil { + return nil, err + } + + return ConnectConfig(ctx, config) +} + +// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with +// ParseConfig. ctx can be used to cancel a connect attempt. +// +// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An +// authentication error will terminate the chain of attempts (like libpq: +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, +// if all attempts fail the last error is returned. +func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) { + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") + } + + // Simplify usage by treating primary config and fallbacks the same. + fallbackConfigs := []*FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + } + fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + ctx := octx + fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) + if err != nil { + return nil, &connectError{config: config, msg: "hostname resolving error", err: err} + } + + if len(fallbackConfigs) == 0 { + return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} + } + + foundBestServer := false + var fallbackConfig *FallbackConfig + for _, fc := range fallbackConfigs { + // ConnectTimeout restricts the whole connection process. + if config.ConnectTimeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) + defer cancel() + } else { + ctx = octx + } + pgConn, err = connect(ctx, config, fc, false) + if err == nil { + foundBestServer = true + break + } else if pgerr, ok := err.(*PgError); ok { + err = &connectError{config: config, msg: "server error", err: pgerr} + const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password + const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings + const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist + const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege + if pgerr.Code == ERRCODE_INVALID_PASSWORD || + pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || + pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || + pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { + break + } + } else if cerr, ok := err.(*connectError); ok { + if _, ok := cerr.err.(*NotPreferredError); ok { + fallbackConfig = fc + } + } + } + + if !foundBestServer && fallbackConfig != nil { + pgConn, err = connect(ctx, config, fallbackConfig, true) + if pgerr, ok := err.(*PgError); ok { + err = &connectError{config: config, msg: "server error", err: pgerr} + } + } + + if err != nil { + return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError + } + + if config.AfterConnect != nil { + err := config.AfterConnect(ctx, pgConn) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "AfterConnect error", err: err} + } + } + + return pgConn, nil +} + +func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { + var configs []*FallbackConfig + + for _, fb := range fallbacks { + // skip resolve for unix sockets + if isAbsolutePath(fb.Host) { + configs = append(configs, &FallbackConfig{ + Host: fb.Host, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + + continue + } + + ips, err := lookupFn(ctx, fb.Host) + if err != nil { + return nil, err + } + + for _, ip := range ips { + splitIP, splitPort, err := net.SplitHostPort(ip) + if err == nil { + port, err := strconv.ParseUint(splitPort, 10, 16) + if err != nil { + return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err) + } + configs = append(configs, &FallbackConfig{ + Host: splitIP, + Port: uint16(port), + TLSConfig: fb.TLSConfig, + }) + } else { + configs = append(configs, &FallbackConfig{ + Host: ip, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + } + } + } + + return configs, nil +} + +func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, + ignoreNotPreferredErr bool) (*PgConn, error) { + pgConn := new(PgConn) + pgConn.config = config + pgConn.cleanupDone = make(chan struct{}) + + var err error + network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) + netConn, err := config.DialFunc(ctx, network, address) + if err != nil { + return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} + } + nbNetConn := nbconn.NewNetConn(netConn, false) + + pgConn.conn = nbNetConn + pgConn.contextWatcher = newContextWatcher(nbNetConn) + pgConn.contextWatcher.Watch(ctx) + + if fallbackConfig.TLSConfig != nil { + nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig) + pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. + if err != nil { + netConn.Close() + return nil, &connectError{config: config, msg: "tls error", err: err} + } + + pgConn.conn = nbTLSConn + pgConn.contextWatcher = newContextWatcher(nbTLSConn) + pgConn.contextWatcher.Watch(ctx) + } + + defer pgConn.contextWatcher.Unwatch() + + pgConn.parameterStatuses = make(map[string]string) + pgConn.status = connStatusConnecting + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) + + startupMsg := pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: make(map[string]string), + } + + // Copy default run-time params + for k, v := range config.RuntimeParams { + startupMsg.Parameters[k] = v + } + + startupMsg.Parameters["user"] = config.User + if config.Database != "" { + startupMsg.Parameters["database"] = config.Database + } + + pgConn.frontend.Send(&startupMsg) + if err := pgConn.frontend.Flush(); err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write startup message", err: err} + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.conn.Close() + if err, ok := err.(*PgError); ok { + return nil, err + } + return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} + } + + switch msg := msg.(type) { + case *pgproto3.BackendKeyData: + pgConn.pid = msg.ProcessID + pgConn.secretKey = msg.SecretKey + + case *pgproto3.AuthenticationOk: + case *pgproto3.AuthenticationCleartextPassword: + err = pgConn.txPasswordMessage(pgConn.config.Password) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write password message", err: err} + } + case *pgproto3.AuthenticationMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write password message", err: err} + } + case *pgproto3.AuthenticationSASL: + err = pgConn.scramAuth(msg.AuthMechanisms) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed SASL auth", err: err} + } + case *pgproto3.AuthenticationGSS: + err = pgConn.gssAuth() + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed GSS auth", err: err} + } + case *pgproto3.ReadyForQuery: + pgConn.status = connStatusIdle + if config.ValidateConnect != nil { + // ValidateConnect may execute commands that cause the context to be watched again. Unwatch first to avoid + // the watch already in progress panic. This is that last thing done by this method so there is no need to + // restart the watch after ValidateConnect returns. + // + // See https://github.com/jackc/pgconn/issues/40. + pgConn.contextWatcher.Unwatch() + + err := config.ValidateConnect(ctx, pgConn) + if err != nil { + if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok { + return pgConn, nil + } + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} + } + } + return pgConn, nil + case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse: + // handled by ReceiveMessage + case *pgproto3.ErrorResponse: + pgConn.conn.Close() + return nil, ErrorResponseToPgError(msg) + default: + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "received unexpected message", err: err} + } + } +} + +func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { + return ctxwatch.NewContextWatcher( + func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { conn.SetDeadline(time.Time{}) }, + ) +} + +func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) { + err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) + if err != nil { + return nil, err + } + + response := make([]byte, 1) + if _, err = io.ReadFull(conn, response); err != nil { + return nil, err + } + + if response[0] != 'S' { + return nil, errors.New("server refused TLS connection") + } + + tlsConn, err := nbconn.TLSClient(conn, tlsConfig) + if err != nil { + return nil, err + } + + return tlsConn, nil +} + +func (pgConn *PgConn) txPasswordMessage(password string) (err error) { + pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password}) + return pgConn.frontend.Flush() +} + +func hexMD5(s string) string { + hash := md5.New() + io.WriteString(hash, s) + return hex.EncodeToString(hash.Sum(nil)) +} + +// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the +// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages +// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger +// the OnNotification callback. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + msg, err := pgConn.receiveMessage() + if err != nil { + err = &pgconnError{ + msg: "receive message failed", + err: normalizeTimeoutError(ctx, err), + safeToRetry: true} + } + return msg, err +} + +// peekMessage peeks at the next message without setting up context cancellation. +func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { + if pgConn.peekedMsg != nil { + return pgConn.peekedMsg, nil + } + + msg, err := pgConn.frontend.Receive() + + if err != nil { + if errors.Is(err, nbconn.ErrWouldBlock) { + return nil, err + } + + // Close on anything other than timeout error - everything else is fatal + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { + pgConn.asyncClose() + } + + return nil, err + } + + pgConn.peekedMsg = msg + return msg, nil +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := pgConn.peekMessage() + if err != nil { + return nil, err + } + pgConn.peekedMsg = nil + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + pgConn.txStatus = msg.TxStatus + case *pgproto3.ParameterStatus: + pgConn.parameterStatuses[msg.Name] = msg.Value + case *pgproto3.ErrorResponse: + if msg.Severity == "FATAL" { + pgConn.status = connStatusClosed + pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. + close(pgConn.cleanupDone) + return nil, ErrorResponseToPgError(msg) + } + case *pgproto3.NoticeResponse: + if pgConn.config.OnNotice != nil { + pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg)) + } + case *pgproto3.NotificationResponse: + if pgConn.config.OnNotification != nil { + pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) + } + } + + return msg, nil +} + +// Conn returns the underlying net.Conn. This rarely necessary. +func (pgConn *PgConn) Conn() net.Conn { + return pgConn.conn +} + +// PID returns the backend PID. +func (pgConn *PgConn) PID() uint32 { + return pgConn.pid +} + +// TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message. +// +// Possible return values: +// +// 'I' - idle / not in transaction +// 'T' - in a transaction +// 'E' - in a failed transaction +// +// See https://www.postgresql.org/docs/current/protocol-message-formats.html. +func (pgConn *PgConn) TxStatus() byte { + return pgConn.txStatus +} + +// SecretKey returns the backend secret key used to send a cancel query message to the server. +func (pgConn *PgConn) SecretKey() uint32 { + return pgConn.secretKey +} + +// Frontend returns the underlying *pgproto3.Frontend. This rarely necessary. +func (pgConn *PgConn) Frontend() *pgproto3.Frontend { + return pgConn.frontend +} + +// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by +// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The +// underlying net.Conn.Close() will always be called regardless of any other errors. +func (pgConn *PgConn) Close(ctx context.Context) error { + if pgConn.status == connStatusClosed { + return nil + } + pgConn.status = connStatusClosed + + defer close(pgConn.cleanupDone) + defer pgConn.conn.Close() + + if ctx != context.Background() { + // Close may be called while a cancellable query is in progress. This will most often be triggered by panic when + // a defer closes the connection (possibly indirectly via a transaction or a connection pool). Unwatch to end any + // previous watch. It is safe to Unwatch regardless of whether a watch is already is progress. + // + // See https://github.com/jackc/pgconn/issues/29 + pgConn.contextWatcher.Unwatch() + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Ignore any errors sending Terminate message and waiting for server to close connection. + // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully + // ignores errors. + // + // See https://github.com/jackc/pgx/issues/637 + pgConn.frontend.Send(&pgproto3.Terminate{}) + pgConn.frontend.Flush() + + return pgConn.conn.Close() +} + +// asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying +// connection. +func (pgConn *PgConn) asyncClose() { + if pgConn.status == connStatusClosed { + return + } + pgConn.status = connStatusClosed + + go func() { + defer close(pgConn.cleanupDone) + defer pgConn.conn.Close() + + deadline := time.Now().Add(time.Second * 15) + + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + pgConn.CancelRequest(ctx) + + pgConn.conn.SetDeadline(deadline) + + pgConn.frontend.Send(&pgproto3.Terminate{}) + pgConn.frontend.Flush() + }() +} + +// CleanupDone returns a channel that will be closed after all underlying resources have been cleaned up. A closed +// connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing +// yet. This is because certain errors such as a context cancellation require that the interrupted function call return +// immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are +// closed asynchronously. +// +// This is only likely to be useful to connection pools. It gives them a way avoid establishing a new connection while +// an old connection is still being cleaned up and thereby exceeding the maximum pool size. +func (pgConn *PgConn) CleanupDone() chan (struct{}) { + return pgConn.cleanupDone +} + +// IsClosed reports if the connection has been closed. +// +// CleanupDone() can be used to determine if all cleanup has been completed. +func (pgConn *PgConn) IsClosed() bool { + return pgConn.status < connStatusIdle +} + +// IsBusy reports if the connection is busy. +func (pgConn *PgConn) IsBusy() bool { + return pgConn.status == connStatusBusy +} + +// lock locks the connection. +func (pgConn *PgConn) lock() error { + switch pgConn.status { + case connStatusBusy: + return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug. + case connStatusClosed: + return &connLockError{status: "conn closed"} + case connStatusUninitialized: + return &connLockError{status: "conn uninitialized"} + } + pgConn.status = connStatusBusy + return nil +} + +func (pgConn *PgConn) unlock() { + switch pgConn.status { + case connStatusBusy: + pgConn.status = connStatusIdle + case connStatusClosed: + default: + panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. + } +} + +// ParameterStatus returns the value of a parameter reported by the server (e.g. +// server_version). Returns an empty string for unknown parameters. +func (pgConn *PgConn) ParameterStatus(key string) string { + return pgConn.parameterStatuses[key] +} + +// CommandTag is the status text returned by PostgreSQL for a query. +type CommandTag struct { + s string +} + +// NewCommandTag makes a CommandTag from s. +func NewCommandTag(s string) CommandTag { + return CommandTag{s: s} +} + +// RowsAffected returns the number of rows affected. If the CommandTag was not +// for a row affecting command (e.g. "CREATE TABLE") then it returns 0. +func (ct CommandTag) RowsAffected() int64 { + // Find last non-digit + idx := -1 + for i := len(ct.s) - 1; i >= 0; i-- { + if ct.s[i] >= '0' && ct.s[i] <= '9' { + idx = i + } else { + break + } + } + + if idx == -1 { + return 0 + } + + var n int64 + for _, b := range ct.s[idx:] { + n = n*10 + int64(b-'0') + } + + return n +} + +func (ct CommandTag) String() string { + return ct.s +} + +// Insert is true if the command tag starts with "INSERT". +func (ct CommandTag) Insert() bool { + return strings.HasPrefix(ct.s, "INSERT") +} + +// Update is true if the command tag starts with "UPDATE". +func (ct CommandTag) Update() bool { + return strings.HasPrefix(ct.s, "UPDATE") +} + +// Delete is true if the command tag starts with "DELETE". +func (ct CommandTag) Delete() bool { + return strings.HasPrefix(ct.s, "DELETE") +} + +// Select is true if the command tag starts with "SELECT". +func (ct CommandTag) Select() bool { + return strings.HasPrefix(ct.s, "SELECT") +} + +type FieldDescription struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 +} + +func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) []FieldDescription { + if cap(dst) >= len(rd.Fields) { + dst = dst[:len(rd.Fields):len(rd.Fields)] + } else { + dst = make([]FieldDescription, len(rd.Fields)) + } + + for i := range rd.Fields { + dst[i].Name = string(rd.Fields[i].Name) + dst[i].TableOID = rd.Fields[i].TableOID + dst[i].TableAttributeNumber = rd.Fields[i].TableAttributeNumber + dst[i].DataTypeOID = rd.Fields[i].DataTypeOID + dst[i].DataTypeSize = rd.Fields[i].DataTypeSize + dst[i].TypeModifier = rd.Fields[i].TypeModifier + dst[i].Format = rd.Fields[i].Format + } + + return dst +} + +type StatementDescription struct { + Name string + SQL string + ParamOIDs []uint32 + Fields []FieldDescription +} + +// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This +// allows Prepare to also to describe statements without creating a server-side prepared statement. +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) + err := pgConn.frontend.Flush() + if err != nil { + pgConn.asyncClose() + return nil, err + } + + psd := &StatementDescription{Name: name, SQL: sql} + + var parseErr error + +readloop: + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, normalizeTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.RowDescription: + psd.Fields = pgConn.convertRowDescription(nil, msg) + case *pgproto3.ErrorResponse: + parseErr = ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + break readloop + } + } + + if parseErr != nil { + return nil, parseErr + } + return psd, nil +} + +// ErrorResponseToPgError converts a wire protocol error message to a *PgError. +func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { + return &PgError{ + Severity: msg.Severity, + Code: string(msg.Code), + Message: string(msg.Message), + Detail: string(msg.Detail), + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: string(msg.InternalQuery), + Where: string(msg.Where), + SchemaName: string(msg.SchemaName), + TableName: string(msg.TableName), + ColumnName: string(msg.ColumnName), + DataTypeName: string(msg.DataTypeName), + ConstraintName: msg.ConstraintName, + File: string(msg.File), + Line: msg.Line, + Routine: string(msg.Routine), + } +} + +func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { + pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg)) + return (*Notice)(pgerr) +} + +// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 +func (pgConn *PgConn) CancelRequest(ctx context.Context) error { + // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing + // the connection config. This is important in high availability configurations where fallback connections may be + // specified or DNS may be used to load balance. + serverAddr := pgConn.conn.RemoteAddr() + cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + if err != nil { + return err + } + defer cancelConn.Close() + + if ctx != context.Background() { + contextWatcher := ctxwatch.NewContextWatcher( + func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { cancelConn.SetDeadline(time.Time{}) }, + ) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() + } + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return err + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return err + } + + return nil +} + +// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not +// received. +func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { + if err := pgConn.lock(); err != nil { + return err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return newContextAlreadyDoneError(ctx) + default: + } + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + return normalizeTimeoutError(ctx, err) + } + + switch msg.(type) { + case *pgproto3.NotificationResponse: + return nil + } + } +} + +// Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is +// implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control +// statements. +// +// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. +func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) + err := pgConn.frontend.Flush() + if err != nil { + pgConn.asyncClose() + pgConn.contextWatcher.Unwatch() + multiResult.closed = true + multiResult.err = err + pgConn.unlock() + return multiResult + } + + return multiResult +} + +// ExecParams executes a command via the PostgreSQL extended query protocol. +// +// sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, +// etc. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for +// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. +// ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text format. +// +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + + pgConn.execExtendedSuffix(result) + + return result +} + +// ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text format. +// +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + + pgConn.execExtendedSuffix(result) + + return result +} + +func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { + pgConn.resultReader = ResultReader{ + pgConn: pgConn, + ctx: ctx, + } + result := &pgConn.resultReader + + if err := pgConn.lock(); err != nil { + result.concludeCommand(CommandTag{}, err) + result.closed = true + return result + } + + if len(paramValues) > math.MaxUint16 { + result.concludeCommand(CommandTag{}, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.closed = true + pgConn.unlock() + return result + } + + if ctx != context.Background() { + select { + case <-ctx.Done(): + result.concludeCommand(CommandTag{}, newContextAlreadyDoneError(ctx)) + result.closed = true + pgConn.unlock() + return result + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + return result +} + +func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + pgConn.frontend.SendExecute(&pgproto3.Execute{}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) + + err := pgConn.frontend.Flush() + if err != nil { + pgConn.asyncClose() + result.concludeCommand(CommandTag{}, err) + pgConn.contextWatcher.Unwatch() + result.closed = true + pgConn.unlock() + return + } + + result.readUntilRowDescription() +} + +// CopyTo executes the copy command sql and copies the results to w. +func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return CommandTag{}, err + } + + if ctx != context.Background() { + select { + case <-ctx.Done(): + pgConn.unlock() + return CommandTag{}, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Send copy to command + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) + + err := pgConn.frontend.Flush() + if err != nil { + pgConn.asyncClose() + pgConn.unlock() + return CommandTag{}, err + } + + // Read results + var commandTag CommandTag + var pgErr error + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return CommandTag{}, normalizeTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.CopyDone: + case *pgproto3.CopyData: + _, err := w.Write(msg.Data) + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + case *pgproto3.ReadyForQuery: + pgConn.unlock() + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = pgConn.makeCommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + } + } +} + +// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server. +// +// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r +// could still block. +func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return CommandTag{}, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return CommandTag{}, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Send copy to command + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) + err := pgConn.frontend.Flush() + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + + err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + nonblocking := true + defer func() { + if nonblocking { + pgConn.conn.SetReadDeadline(time.Time{}) + } + }() + + buf := iobufpool.Get(65536) + defer iobufpool.Put(buf) + buf[0] = 'd' + + var readErr, pgErr error + for pgErr == nil { + // Read chunk from r. + var n int + n, readErr = r.Read(buf[5:cap(buf)]) + + // Send chunk to PostgreSQL. + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[1:], int32(n+4)) + + writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf) + if writeErr != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + } + + // Abort loop if there was a read error. + if readErr != nil { + break + } + + // Read messages until error or none available. + for pgErr == nil { + msg, err := pgConn.receiveMessage() + if err != nil { + if errors.Is(err, nbconn.ErrWouldBlock) { + break + } + pgConn.asyncClose() + return CommandTag{}, normalizeTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + break + } + } + } + + err = pgConn.conn.SetReadDeadline(time.Time{}) + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + nonblocking = false + + if readErr == io.EOF || pgErr != nil { + pgConn.frontend.Send(&pgproto3.CopyDone{}) + } else { + pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()}) + } + err = pgConn.frontend.Flush() + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + + // Read results + var commandTag CommandTag + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return CommandTag{}, normalizeTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = pgConn.makeCommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + } + } +} + +// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. +type MultiResultReader struct { + pgConn *PgConn + ctx context.Context + pipeline *Pipeline + + rr *ResultReader + + closed bool + err error +} + +// ReadAll reads all available results. Calling ReadAll is mutually exclusive with all other MultiResultReader methods. +func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { + var results []*Result + + for mrr.NextResult() { + results = append(results, mrr.ResultReader().Read()) + } + err := mrr.Close() + + return results, err +} + +func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := mrr.pgConn.receiveMessage() + + if err != nil { + mrr.pgConn.contextWatcher.Unwatch() + mrr.err = normalizeTimeoutError(mrr.ctx, err) + mrr.closed = true + mrr.pgConn.asyncClose() + return nil, mrr.err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + mrr.closed = true + if mrr.pipeline != nil { + mrr.pipeline.expectedReadyForQueryCount-- + } else { + mrr.pgConn.contextWatcher.Unwatch() + mrr.pgConn.unlock() + } + case *pgproto3.ErrorResponse: + mrr.err = ErrorResponseToPgError(msg) + } + + return msg, nil +} + +// NextResult returns advances the MultiResultReader to the next result and returns true if a result is available. +func (mrr *MultiResultReader) NextResult() bool { + for !mrr.closed && mrr.err == nil { + msg, err := mrr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + mrr.pgConn.resultReader = ResultReader{ + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, + fieldDescriptions: mrr.pgConn.convertRowDescription(mrr.pgConn.fieldDescriptions[:], msg), + } + + mrr.rr = &mrr.pgConn.resultReader + return true + case *pgproto3.CommandComplete: + mrr.pgConn.resultReader = ResultReader{ + commandTag: mrr.pgConn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + mrr.rr = &mrr.pgConn.resultReader + return true + case *pgproto3.EmptyQueryResponse: + return false + } + } + + return false +} + +// ResultReader returns the current ResultReader. +func (mrr *MultiResultReader) ResultReader() *ResultReader { + return mrr.rr +} + +// Close closes the MultiResultReader and returns the first error that occurred during the MultiResultReader's use. +func (mrr *MultiResultReader) Close() error { + for !mrr.closed { + _, err := mrr.receiveMessage() + if err != nil { + return mrr.err + } + } + + return mrr.err +} + +// ResultReader is a reader for the result of a single query. +type ResultReader struct { + pgConn *PgConn + multiResultReader *MultiResultReader + pipeline *Pipeline + ctx context.Context + + fieldDescriptions []FieldDescription + rowValues [][]byte + commandTag CommandTag + commandConcluded bool + closed bool + err error +} + +// Result is the saved query response that is returned by calling Read on a ResultReader. +type Result struct { + FieldDescriptions []FieldDescription + Rows [][][]byte + CommandTag CommandTag + Err error +} + +// Read saves the query response to a Result. +func (rr *ResultReader) Read() *Result { + br := &Result{} + + for rr.NextRow() { + if br.FieldDescriptions == nil { + br.FieldDescriptions = make([]FieldDescription, len(rr.FieldDescriptions())) + copy(br.FieldDescriptions, rr.FieldDescriptions()) + } + + values := rr.Values() + row := make([][]byte, len(values)) + for i := range row { + row[i] = make([]byte, len(values[i])) + copy(row[i], values[i]) + } + br.Rows = append(br.Rows, row) + } + + br.CommandTag, br.Err = rr.Close() + + return br +} + +// NextRow advances the ResultReader to the next row and returns true if a row is available. +func (rr *ResultReader) NextRow() bool { + for !rr.commandConcluded { + msg, err := rr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.DataRow: + rr.rowValues = msg.Values + return true + } + } + + return false +} + +// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until +// the ResultReader is closed. +func (rr *ResultReader) FieldDescriptions() []FieldDescription { + return rr.fieldDescriptions +} + +// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only +// valid until the next NextRow call or the ResultReader is closed. +func (rr *ResultReader) Values() [][]byte { + return rr.rowValues +} + +// Close consumes any remaining result data and returns the command tag or +// error. +func (rr *ResultReader) Close() (CommandTag, error) { + if rr.closed { + return rr.commandTag, rr.err + } + rr.closed = true + + for !rr.commandConcluded { + _, err := rr.receiveMessage() + if err != nil { + return CommandTag{}, rr.err + } + } + + if rr.multiResultReader == nil && rr.pipeline == nil { + for { + msg, err := rr.receiveMessage() + if err != nil { + return CommandTag{}, rr.err + } + + switch msg := msg.(type) { + // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. + case *pgproto3.ErrorResponse: + rr.err = ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + rr.pgConn.contextWatcher.Unwatch() + rr.pgConn.unlock() + return rr.commandTag, rr.err + } + } + } + + return rr.commandTag, rr.err +} + +// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any +// error will be stored in the ResultReader. +func (rr *ResultReader) readUntilRowDescription() { + for !rr.commandConcluded { + // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. + // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are + // manually used to construct a query that does not issue a describe statement. + msg, _ := rr.pgConn.peekMessage() + if _, ok := msg.(*pgproto3.DataRow); ok { + return + } + + // Consume the message + msg, _ = rr.receiveMessage() + if _, ok := msg.(*pgproto3.RowDescription); ok { + return + } + } +} + +func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { + if rr.multiResultReader == nil { + msg, err = rr.pgConn.receiveMessage() + } else { + msg, err = rr.multiResultReader.receiveMessage() + } + + if err != nil { + err = normalizeTimeoutError(rr.ctx, err) + rr.concludeCommand(CommandTag{}, err) + rr.pgConn.contextWatcher.Unwatch() + rr.closed = true + if rr.multiResultReader == nil { + rr.pgConn.asyncClose() + } + + return nil, rr.err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rr.fieldDescriptions = rr.pgConn.convertRowDescription(rr.pgConn.fieldDescriptions[:], msg) + case *pgproto3.CommandComplete: + rr.concludeCommand(rr.pgConn.makeCommandTag(msg.CommandTag), nil) + case *pgproto3.EmptyQueryResponse: + rr.concludeCommand(CommandTag{}, nil) + case *pgproto3.ErrorResponse: + rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg)) + } + + return msg, nil +} + +func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { + // Keep the first error that is recorded. Store the error before checking if the command is already concluded to + // allow for receiving an error after CommandComplete but before ReadyForQuery. + if err != nil && rr.err == nil { + rr.err = err + } + + if rr.commandConcluded { + return + } + + rr.commandTag = commandTag + rr.rowValues = nil + rr.commandConcluded = true +} + +// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. +type Batch struct { + buf []byte +} + +// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. +func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + batch.ExecPrepared("", paramValues, paramFormats, resultFormats) +} + +// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. +func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) +} + +// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a +// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing +// multiple queries in a single round trip than using pipeline mode. +func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + + _, err := pgConn.conn.Write(batch.buf) + if err != nil { + multiResult.closed = true + multiResult.err = err + pgConn.unlock() + return multiResult + } + + return multiResult +} + +// EscapeString escapes a string such that it can safely be interpolated into a SQL command string. It does not include +// the surrounding single quotes. +// +// The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these +// conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future. +func (pgConn *PgConn) EscapeString(s string) (string, error) { + if pgConn.ParameterStatus("standard_conforming_strings") != "on" { + return "", errors.New("EscapeString must be run with standard_conforming_strings=on") + } + + if pgConn.ParameterStatus("client_encoding") != "UTF8" { + return "", errors.New("EscapeString must be run with client_encoding=UTF8") + } + + return strings.Replace(s, "'", "''", -1), nil +} + +// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and +// buffering until the read would block or an error occurs. This can be used to check if the server has closed the +// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails +// without the client knowing whether the server received it or not. +func (pgConn *PgConn) CheckConn() error { + err := pgConn.conn.BufferReadUntilBlock() + if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) { + return err + } + return nil +} + +// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. +func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { + return CommandTag{s: string(buf)} +} + +// HijackedConn is the result of hijacking a connection. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +type HijackedConn struct { + Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection + PID uint32 // backend pid + SecretKey uint32 // key to use to send a cancel query message to the server + ParameterStatuses map[string]string // parameters that have been reported by the server + TxStatus byte + Frontend *pgproto3.Frontend + Config *Config +} + +// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. +// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the +// raw connection after that (e.g. a load balancer or proxy). +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func (pgConn *PgConn) Hijack() (*HijackedConn, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + pgConn.status = connStatusClosed + + return &HijackedConn{ + Conn: pgConn.conn, + PID: pgConn.pid, + SecretKey: pgConn.secretKey, + ParameterStatuses: pgConn.parameterStatuses, + TxStatus: pgConn.txStatus, + Frontend: pgConn.frontend, + Config: pgConn.config, + }, nil +} + +// Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of +// PgConn.Hijack. The connection must be in an idle state. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func Construct(hc *HijackedConn) (*PgConn, error) { + pgConn := &PgConn{ + conn: hc.Conn, + pid: hc.PID, + secretKey: hc.SecretKey, + parameterStatuses: hc.ParameterStatuses, + txStatus: hc.TxStatus, + frontend: hc.Frontend, + config: hc.Config, + + status: connStatusIdle, + + cleanupDone: make(chan struct{}), + } + + pgConn.contextWatcher = newContextWatcher(pgConn.conn) + + return pgConn, nil +} + +// Pipeline represents a connection in pipeline mode. +// +// SendPrepare, SendQueryParams, and SendQueryPrepared queue requests to the server. These requests are not written until +// pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. Requests between +// synchronization points are implicitly transactional unless explicit transaction control statements have been issued. +// +// The context the pipeline was started with is in effect for the entire life of the Pipeline. +// +// For a deeper understanding of pipeline mode see the PostgreSQL documentation for the extended query protocol +// (https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY) and the libpq pipeline mode +// (https://www.postgresql.org/docs/current/libpq-pipeline-mode.html). +type Pipeline struct { + conn *PgConn + ctx context.Context + + expectedReadyForQueryCount int + pendingSync bool + + err error + closed bool +} + +// PipelineSync is returned by GetResults when a ReadyForQuery message is received. +type PipelineSync struct{} + +// CloseComplete is returned by GetResults when a CloseComplete message is received. +type CloseComplete struct{} + +// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent +// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection +// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except +// CancelRequest and Close. ctx is in effect for entire life of the *Pipeline. +// +// Prefer ExecBatch when only sending one group of queries at once. +func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline { + if err := pgConn.lock(); err != nil { + return &Pipeline{ + closed: true, + err: err, + } + } + + pgConn.pipeline = Pipeline{ + conn: pgConn, + ctx: ctx, + } + pipeline := &pgConn.pipeline + + if ctx != context.Background() { + select { + case <-ctx.Done(): + pipeline.closed = true + pipeline.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return pipeline + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + return pipeline +} + +// SendPrepare is the pipeline version of *PgConn.Prepare. +func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) +} + +// SendDeallocate deallocates a prepared statement. +func (p *Pipeline) SendDeallocate(name string) { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) +} + +// SendQueryParams is the pipeline version of *PgConn.QueryParams. +func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) + p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + p.conn.frontend.SendExecute(&pgproto3.Execute{}) +} + +// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. +func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + p.conn.frontend.SendExecute(&pgproto3.Execute{}) +} + +// Flush flushes the queued requests without establishing a synchronization point. +func (p *Pipeline) Flush() error { + if p.closed { + if p.err != nil { + return p.err + } + return errors.New("pipeline closed") + } + + err := p.conn.frontend.Flush() + if err != nil { + err = normalizeTimeoutError(p.ctx, err) + + p.conn.asyncClose() + + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + p.closed = true + p.err = err + return err + } + + return nil +} + +// Sync establishes a synchronization point and flushes the queued requests. +func (p *Pipeline) Sync() error { + p.conn.frontend.SendSync(&pgproto3.Sync{}) + err := p.Flush() + if err != nil { + return err + } + + p.pendingSync = false + p.expectedReadyForQueryCount++ + + return nil +} + +// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or +// *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no +// results are available, results and err will both be nil. +func (p *Pipeline) GetResults() (results any, err error) { + if p.expectedReadyForQueryCount == 0 { + return nil, nil + } + + for { + msg, err := p.conn.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + p.conn.resultReader = ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: p.conn.convertRowDescription(p.conn.fieldDescriptions[:], msg), + } + return &p.conn.resultReader, nil + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + return &p.conn.resultReader, nil + case *pgproto3.ParseComplete: + peekedMsg, err := p.conn.peekMessage() + if err != nil { + return nil, err + } + if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { + return p.getResultsPrepare() + } + case *pgproto3.CloseComplete: + return &CloseComplete{}, nil + case *pgproto3.ReadyForQuery: + p.expectedReadyForQueryCount-- + return &PipelineSync{}, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + return nil, pgErr + } + + } + +} + +func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { + psd := &StatementDescription{} + + for { + msg, err := p.conn.receiveMessage() + if err != nil { + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.RowDescription: + psd.Fields = p.conn.convertRowDescription(nil, msg) + return psd, nil + + // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING + // clause. + case *pgproto3.NoData: + return psd, nil + + // These should never happen here. But don't take chances that could lead to a deadlock. + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + return nil, pgErr + case *pgproto3.CommandComplete: + p.conn.asyncClose() + return nil, errors.New("BUG: received CommandComplete while handling Describe") + case *pgproto3.ReadyForQuery: + p.conn.asyncClose() + return nil, errors.New("BUG: received ReadyForQuery while handling Describe") + } + } +} + +// Close closes the pipeline and returns the connection to normal mode. +func (p *Pipeline) Close() error { + if p.closed { + return p.err + } + p.closed = true + + if p.pendingSync { + p.conn.asyncClose() + p.err = errors.New("pipeline has unsynced requests") + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + + return p.err + } + + for p.expectedReadyForQueryCount > 0 { + _, err := p.GetResults() + if err != nil { + p.err = err + var pgErr *PgError + if !errors.As(err, &pgErr) { + p.conn.asyncClose() + break + } + } + } + + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + + return p.err +} diff --git a/pgconn/pgconn_private_test.go b/pgconn/pgconn_private_test.go new file mode 100644 index 00000000..5659bc9e --- /dev/null +++ b/pgconn/pgconn_private_test.go @@ -0,0 +1,41 @@ +package pgconn + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCommandTag(t *testing.T) { + t.Parallel() + + var tests = []struct { + commandTag CommandTag + rowsAffected int64 + isInsert bool + isUpdate bool + isDelete bool + isSelect bool + }{ + {commandTag: CommandTag{s: "INSERT 0 5"}, rowsAffected: 5, isInsert: true}, + {commandTag: CommandTag{s: "UPDATE 0"}, rowsAffected: 0, isUpdate: true}, + {commandTag: CommandTag{s: "UPDATE 1"}, rowsAffected: 1, isUpdate: true}, + {commandTag: CommandTag{s: "DELETE 0"}, rowsAffected: 0, isDelete: true}, + {commandTag: CommandTag{s: "DELETE 1"}, rowsAffected: 1, isDelete: true}, + {commandTag: CommandTag{s: "DELETE 1234567890"}, rowsAffected: 1234567890, isDelete: true}, + {commandTag: CommandTag{s: "SELECT 1"}, rowsAffected: 1, isSelect: true}, + {commandTag: CommandTag{s: "SELECT 99999999999"}, rowsAffected: 99999999999, isSelect: true}, + {commandTag: CommandTag{s: "CREATE TABLE"}, rowsAffected: 0}, + {commandTag: CommandTag{s: "ALTER TABLE"}, rowsAffected: 0}, + {commandTag: CommandTag{s: "DROP TABLE"}, rowsAffected: 0}, + } + + for i, tt := range tests { + ct := tt.commandTag + assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) + } +} diff --git a/pgconn/pgconn_stress_test.go b/pgconn/pgconn_stress_test.go new file mode 100644 index 00000000..3d72964f --- /dev/null +++ b/pgconn/pgconn_stress_test.go @@ -0,0 +1,90 @@ +package pgconn_test + +import ( + "context" + "math/rand" + "os" + "runtime" + "strconv" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + + "github.com/stretchr/testify/require" +) + +func TestConnStress(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + actionCount := 10000 + if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" { + stressFactor, err := strconv.ParseInt(s, 10, 64) + require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR") + actionCount *= int(stressFactor) + } + + setupStressDB(t, pgConn) + + actions := []struct { + name string + fn func(*pgconn.PgConn) error + }{ + {"Exec Select", stressExecSelect}, + {"ExecParams Select", stressExecParamsSelect}, + {"Batch", stressBatch}, + } + + for i := 0; i < actionCount; i++ { + action := actions[rand.Intn(len(actions))] + err := action.fn(pgConn) + require.Nilf(t, err, "%d: %s", i, action.name) + } + + // Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled. + numGoroutine := runtime.NumGoroutine() + require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine) +} + +func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { + _, err := pgConn.Exec(context.Background(), ` + create temporary table widgets( + id serial primary key, + name varchar not null, + description text, + creation_time timestamptz default now() + ); + + insert into widgets(name, description) values + ('Foo', 'bar'), + ('baz', 'Something really long Something really long Something really long Something really long Something really long'), + ('a', 'b')`).ReadAll() + require.NoError(t, err) +} + +func stressExecSelect(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := pgConn.Exec(ctx, "select * from widgets").ReadAll() + return err +} + +func stressExecParamsSelect(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() + return result.Err +} + +func stressBatch(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + batch := &pgconn.Batch{} + + batch.ExecParams("select * from widgets", nil, nil, nil, nil) + batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + _, err := pgConn.ExecBatch(ctx, batch).ReadAll() + return err +} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go new file mode 100644 index 00000000..5b6ca284 --- /dev/null +++ b/pgconn/pgconn_test.go @@ -0,0 +1,2817 @@ +package pgconn_test + +import ( + "bytes" + "compress/gzip" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "math" + "net" + "os" + "strconv" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/internal/pgmock" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConnect(t *testing.T) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } + + conn, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) + + closeConn(t, conn) + }) + } +} + +func TestConnectWithOptions(t *testing.T) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } + var sslOptions pgconn.ParseConfigOptions + sslOptions.GetSSLPassword = GetSSLPassword + conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions) + require.NoError(t, err) + + closeConn(t, conn) + }) + } +} + +// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure +// connection. +func TestConnectTLS(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } + + var conn *pgconn.PgConn + var err error + + var sslOptions pgconn.ParseConfigOptions + sslOptions.GetSSLPassword = GetSSLPassword + config, err := pgconn.ParseConfigWithOptions(connString, sslOptions) + require.Nil(t, err) + + conn, err = pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + + result := conn.ExecParams(context.Background(), `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, 1) + require.Len(t, result.Rows[0], 1) + require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection") + + closeConn(t, conn) +} + +type pgmockWaitStep time.Duration + +func (s pgmockWaitStep) Step(*pgproto3.Backend) error { + time.Sleep(time.Duration(s)) + return nil +} + +func TestConnectTimeout(t *testing.T) { + t.Parallel() + tests := []struct { + name string + connect func(connStr string) error + }{ + { + name: "via context that times out", + connect: func(connStr string) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + _, err := pgconn.Connect(ctx, connStr) + return err + }, + }, + { + name: "via config ConnectTimeout", + connect: func(connStr string) error { + conf, err := pgconn.ParseConfig(connStr) + require.NoError(t, err) + conf.ConnectTimeout = time.Microsecond * 50 + _, err = pgconn.ConnectConfig(context.Background(), conf) + return err + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + script := &pgmock.Script{ + Steps: []pgmock.Step{ + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + pgmock.SendMessage(&pgproto3.AuthenticationOk{}), + pgmockWaitStep(time.Millisecond * 500), + pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + }, + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(conn, conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + tooLate := time.Now().Add(time.Millisecond * 500) + + err = tt.connect(connStr) + require.True(t, pgconn.Timeout(err), err) + require.True(t, time.Now().Before(tooLate)) + }) + } +} + +func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) { + t.Parallel() + tests := []struct { + name string + connect func(connStr string) error + }{ + { + name: "via context that times out", + connect: func(connStr string) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + _, err := pgconn.Connect(ctx, connStr) + return err + }, + }, + { + name: "via config ConnectTimeout", + connect: func(connStr string) error { + conf, err := pgconn.ParseConfig(connStr) + require.NoError(t, err) + conf.ConnectTimeout = time.Millisecond * 10 + _, err = pgconn.ConnectConfig(context.Background(), conf) + return err + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error) + defer close(serverErrChan) + go func() { + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + var buf []byte + _, err = conn.Read(buf) + if err != nil { + serverErrChan <- err + return + } + + // Sleeping to hang the TLS handshake. + time.Sleep(time.Minute) + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("host=%s port=%s", host, port) + + errChan := make(chan error) + go func() { + err := tt.connect(connStr) + errChan <- err + }() + + select { + case err = <-errChan: + require.True(t, pgconn.Timeout(err), err) + case err = <-serverErrChan: + t.Fatalf("server failed with error: %s", err) + case <-time.After(time.Millisecond * 100): + t.Fatal("exceeded connection timeout without erroring out") + } + }) + } +} + +func TestConnectInvalidUser(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + config.User = "pgxinvalidusertest" + + _, err = pgconn.ConnectConfig(context.Background(), config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) + if !ok { + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } +} + +func TestConnectWithConnectionRefused(t *testing.T) { + t.Parallel() + + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") + if err == nil { + conn.Close(context.Background()) + t.Fatal("Expected error establishing connection to bad port") + } +} + +func TestConnectCustomDialer(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, dialed) + closeConn(t, conn) +} + +func TestConnectCustomLookup(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) +} + +func TestConnectCustomLookupWithPort(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + origPort := config.Port + // Chnage the config an invalid port so it will fail if used + config.Port = 0 + + looked := false + config.LookupFunc = func(ctx context.Context, host string) ([]string, error) { + looked = true + addrs, err := net.LookupHost(host) + if err != nil { + return nil, err + } + for i := range addrs { + addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10)) + } + return addrs, nil + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) +} + +func TestConnectWithRuntimeParams(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, conn) + + result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) + + result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) +} + +func TestConnectWithFallback(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here + + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) +} + +func TestConnectWithValidateConnect(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount++ + return net.Dial(network, address) + } + + acceptConnCount := 0 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ + if acceptConnCount < 2 { + return errors.New("reject first conn") + } + return nil + } + + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) +} + +func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := pgconn.ConnectConfig(ctx, config) + if !assert.NotNil(t, err) { + conn.Close(ctx) + } +} + +func TestConnectWithAfterConnect(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + + results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) + + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) +} + +func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { + t.Parallel() + + config := &pgconn.Config{} + + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) +} + +func TestConnPrepareSyntaxError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnPrepareContextPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil) + assert.Nil(t, psd) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + + ensureConnValid(t, pgConn) +} + +func TestConnExec(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + +func TestConnExecEmpty(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + multiResult := pgConn.Exec(context.Background(), ";") + + resultCount := 0 + for multiResult.NextResult() { + resultCount++ + multiResult.ResultReader().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecMultipleQueries(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 2) + + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", results[1].CommandTag.String()) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + +func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num") + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, "msg", mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, "num", mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.False(t, mrr.NextResult()) + + require.NoError(t, mrr.Close()) + + ensureConnValid(t, pgConn) +} + +func TestConnExecMultipleQueriesError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() + require.NotNil(t, err) + if pgErr, ok := err.(*pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } + + if pgConn.ParameterStatus("crdb_version") != "" { + // CockroachDB starts the second query result set and then sends the divide by zero error. + require.Len(t, results, 2) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results[1].Rows, 0) + } else { + // PostgreSQL sends the divide by zero and never sends the second query result set. + require.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + } + + ensureConnValid(t, pgConn) +} + +func TestConnExecDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) + + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnExecContextCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)") + + for multiResult.NextResult() { + } + err = multiResult.Close() + assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnExecContextPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", commandTag.String()) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsMaxNumberOfParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsTooManyParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Equal(t, pgconn.CommandTag{}, commandTag) + assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnExecParamsPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() + require.Error(t, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(result.Err)) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsEmptySQL(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read() + assert.Equal(t, pgconn.CommandTag{}, result.CommandTag) + assert.Len(t, result.Rows, 0) + assert.NoError(t, result.Err) + + ensureConnValid(t, pgConn) +} + +// https://github.com/jackc/pgx/issues/859 +func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + assert.Equal(t, len(result.Values()[0]), cap(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", commandTag.String()) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPrepared(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", commandTag.String()) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedTooManyParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + if pgConn.ParameterStatus("crdb_version") != "" { + // CockroachDB rejects preparing a statement with more than 65535 parameters. + require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)") + } else { + // PostgreSQL accepts preparing a statement with more than 65535 parameters and only fails when executing it through the extended protocol. + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters") + } + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Equal(t, pgconn.CommandTag{}, commandTag) + assert.True(t, pgconn.Timeout(err)) + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnExecPreparedPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() + require.Error(t, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(result.Err)) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedEmptySQL(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(ctx, "ps1", "", nil) + require.NoError(t, err) + + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() + assert.Equal(t, pgconn.CommandTag{}, result.CommandTag) + assert.Len(t, result.Rows, 0) + assert.NoError(t, result.Err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecBatch(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, 3) + + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", results[1].CommandTag.String()) + + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", results[2].CommandTag.String()) +} + +func TestConnExecBatchDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnExecBatchPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + + ensureConnValid(t, pgConn) +} + +// Without concurrent reading and writing large batches can deadlock. +// +// See https://github.com/jackc/pgx/issues/374. +func TestConnExecBatchHuge(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + batch := &pgconn.Batch{} + + queryCount := 100000 + args := make([]string, queryCount) + + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } + + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) + + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", results[i].CommandTag.String()) + } +} + +func TestConnExecBatchImplicitTransaction(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)") + } + + _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.Error(t, err) + + result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) +} + +func TestConnLocking(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") + _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) + + results, err := mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + +func TestConnOnNotice(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } + config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the message we expect. + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)") + } + + multiResult := pgConn.Exec(context.Background(), `do $$ +begin + raise notice 'hello, world'; +end$$;`) + err = multiResult.Close() + require.NoError(t, err) + assert.Equal(t, "hello, world", msg) + + ensureConnValid(t, pgConn) +} + +func TestConnOnNotification(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + } + + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) + + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() + require.NoError(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotification(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + } + + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) + + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) + + err = pgConn.WaitForNotification(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotificationPrecanceled(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = pgConn.WaitForNotification(ctx) + require.ErrorIs(t, err, context.Canceled) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotificationTimeout(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + err = pgConn.WaitForNotification(ctx) + cancel() + assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToSmall(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does support COPY TO") + } + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.NoError(t, err) + + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) + + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToLarge(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does support COPY TO") + } + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.NoError(t, err) + + inputBytes := make([]byte, 0) + + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + require.NoError(t, err) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) + + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToQueryError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + outputWriter := bytes.NewBuffer(make([]byte, 0)) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + + outputWriter := &bytes.Buffer{} + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") + assert.Error(t, err) + assert.Equal(t, pgconn.CommandTag{}, res) + + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnCopyToPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + outputWriter := &bytes.Buffer{} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") + require.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + assert.Equal(t, pgconn.CommandTag{}, res) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFrom(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)") + } + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 1000000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Microsecond) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") + cancel() + assert.Equal(t, int64(0), ct.RowsAffected()) + assert.Error(t, err) + + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnCopyFromPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 1000000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Microsecond) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + assert.Equal(t, pgconn.CommandTag{}, ct) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromGzipReader(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)") + } + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) + + gw := gzip.NewWriter(f) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + err = gw.Close() + require.NoError(t, err) + + _, err = f.Seek(0, 0) + require.NoError(t, err) + + gr, err := gzip.NewReader(f) + require.NoError(t, err) + + ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + err = gr.Close() + require.NoError(t, err) + + err = f.Close() + require.NoError(t, err) + + err = os.Remove(f.Name()) + require.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQuerySyntaxError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + // Send data even though the COPY FROM command will be rejected with a syntax error. This ensures that this does not + // break the connection. See https://github.com/jackc/pgconn/pull/127 for context. + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo FROM STDIN WITH (FORMAT csv)") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQueryNoTableError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + srcBuf := &bytes.Buffer{} + + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +// https://github.com/jackc/pgconn/issues/21 +func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) { + t.Parallel() + + ctx := context.Background() + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)") + } + + _, err = pgConn.Exec(ctx, `create temporary table sentences( + t text, + ts tsvector + )`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$ + begin + new.ts := to_tsvector(new.t); + return new; + end + $$ language plpgsql;`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll() + require.NoError(t, err) + + longString := make([]byte, 10001) + for i := range longString { + longString[i] = 'x' + } + + buf := &bytes.Buffer{} + for i := 0; i < 1000; i++ { + buf.Write([]byte(fmt.Sprintf("%s\n", string(longString)))) + } + + _, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) +} + +func TestConnEscapeString(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, + } + + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) + } + } + + ensureConnValid(t, pgConn) +} + +func TestConnCancelRequest(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") + + go func() { + // The query is actually sent when multiResult.NextResult() is called. So wait to ensure it is sent. + // Once Flush is available this could use that instead. + time.Sleep(500 * time.Millisecond) + + err := pgConn.CancelRequest(context.Background()) + require.NoError(t, err) + }() + + for multiResult.NextResult() { + } + err = multiResult.Close() + + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) + + ensureConnValid(t, pgConn) +} + +// https://github.com/jackc/pgx/issues/659 +func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pid := pgConn.PID() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(30)") + + for multiResult.NextResult() { + } + err = multiResult.Close() + assert.True(t, pgconn.Timeout(err)) + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } + + otherConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, otherConn) + + ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + for { + result := otherConn.ExecParams(ctx, + `select 1 from pg_stat_activity where pid=$1`, + [][]byte{[]byte(strconv.FormatInt(int64(pid), 10))}, + nil, + nil, + nil, + ).Read() + require.NoError(t, result.Err) + + if len(result.Rows) == 0 { + break + } + } +} + +func TestHijackAndConstruct(t *testing.T) { + t.Parallel() + + origConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + hc, err := origConn.Hijack() + require.NoError(t, err) + + _, err = origConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + require.Error(t, err) + + newConn, err := pgconn.Construct(hc) + require.NoError(t, err) + + defer closeConn(t, newConn) + + results, err := newConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, newConn) +} + +func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + ctx, _ := context.WithCancel(context.Background()) + pgConn.Exec(ctx, "select n from generate_series(1,10) n") + + closeCtx, _ := context.WithCancel(context.Background()) + pgConn.Close(closeCtx) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +// https://github.com/jackc/pgx/issues/800 +func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) { + t.Parallel() + + steps := pgmock.AcceptUnauthenticatedConnRequestSteps() + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{})) + steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + {Name: []byte("mock")}, + }})) + steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")})) + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + + script := &pgmock.Script{Steps: steps} + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(conn, conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + conn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + + rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil) + + for rr.NextRow() { + } + + _, err = rr.Close() + require.Error(t, err) +} + +// https://github.com/jackc/pgconn/issues/27 +func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "set client_min_messages = debug5").ReadAll() + require.NoError(t, err) + + // The actual contents of this test aren't important. What's important is a large amount of data to be written and + // because of client_min_messages = debug5 the server will return a large amount of data. + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnCheckConn(t *testing.T) { + t.Parallel() + + // Intentionally using TCP connection for more predictable close behavior. (Not sure if Unix domain sockets would behave subtlely different.) + + connString := os.Getenv(os.Getenv("PGX_TEST_TCP_CONN_STRING")) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + c1, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) + defer c1.Close(context.Background()) + + if c1.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + } + + err = c1.CheckConn() + require.NoError(t, err) + + c2, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) + defer c2.Close(context.Background()) + + _, err = c2.Exec(context.Background(), fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll() + require.NoError(t, err) + + // Give a little time for the signal to actually kill the backend. + time.Sleep(500 * time.Millisecond) + + err = c1.CheckConn() + require.Error(t, err) +} + +func TestPipelinePrepare(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(context.Background(), `create temporary table t (id text primary key)`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil) + pipeline.SendPrepare("selectText", "select $1::text as b", nil) + pipeline.SendPrepare("selectNoParams", "select 42 as c", nil) + pipeline.SendPrepare("insertNoResults", "insert into t (id) values ($1)", nil) + pipeline.SendPrepare("insertNoParamsOrResults", "insert into t (id) values ('foo')", nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "a") + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "b") + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "c") + require.Equal(t, []uint32{}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 0) + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 0) + require.Len(t, sd.ParamOIDs, 0) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelinePrepareError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil) + pipeline.SendPrepare("selectError", "bad", nil) + pipeline.SendPrepare("selectText", "select $1::text as b", nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "a") + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Nil(t, results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelinePrepareAndDeallocate(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil) + pipeline.SendDeallocate("selectInt") + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "a") + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.CloseComplete) + require.Truef(t, ok, "expected CloseComplete, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineQuery(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "4", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelinePrepareQuery(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendPrepare("ps", "select $1::text as msg", nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "msg") + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "hello", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "goodbye", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 6`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + var pgErr *pgconn.PgError + require.ErrorAs(t, readResult.Err, &pgErr) + require.Equal(t, "22012", pgErr.Code) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "6", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineCloseReadsUnreadResults(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + err = pipeline.Close() + require.EqualError(t, err, "pipeline has unsynced requests") +} + +func Example() { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + if err != nil { + log.Fatalln(err) + } + defer pgConn.Close(context.Background()) + + result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() + if result.Err != nil { + log.Fatalln(result.Err) + } + + for _, row := range result.Rows { + fmt.Println(string(row[0])) + } + + fmt.Println(result.CommandTag) + // Output: + // 1 + // 2 + // 3 + // SELECT 3 +} + +func GetSSLPassword(ctx context.Context) string { + connString := os.Getenv("PGX_SSL_PASSWORD") + return connString +} + +var rsaCertPEM = `-----BEGIN CERTIFICATE----- +MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx +NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct +Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39 +tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d +9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp +0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv +MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E +FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o +6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2 +gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I +81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB +Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf +hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS +VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27 +MlascjupnaptKX/wMA== +-----END CERTIFICATE----- +` + +var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv +ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx +Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf +bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo +qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM +Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK +o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs +WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa +ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv +Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B +QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+ +QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC +CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods +bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3 +1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2 +SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6 +MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G +McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC +I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD +QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf +k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS +lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4 +TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr +5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi +UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T +z3w+CgS20UrbLIR1YXfqUXge1g== +-----END TESTING KEY----- +`) + +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } + +func TestSNISupport(t *testing.T) { + t.Parallel() + tests := []struct { + name string + sni_param string + sni_set bool + }{ + { + name: "SNI is passed by default", + sni_param: "", + sni_set: true, + }, + { + name: "SNI is passed when asked for", + sni_param: "sslsni=1", + sni_set: true, + }, + { + name: "SNI is not passed when disabled", + sni_param: "sslsni=0", + sni_set: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + serverSNINameChan := make(chan string, 1) + defer close(serverErrChan) + defer close(serverSNINameChan) + + go func() { + var sniHost string + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + backend := pgproto3.NewBackend(conn, conn) + startupMessage, err := backend.ReceiveStartupMessage() + if err != nil { + serverErrChan <- err + return + } + + switch startupMessage.(type) { + case *pgproto3.SSLRequest: + _, err = conn.Write([]byte("S")) + if err != nil { + serverErrChan <- err + return + } + default: + serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage) + return + } + + cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM)) + if err != nil { + serverErrChan <- err + return + } + + srv := tls.Server(conn, &tls.Config{ + Certificates: []tls.Certificate{cert}, + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + sniHost = argHello.ServerName + return nil, nil + }, + }) + defer srv.Close() + + if err := srv.Handshake(); err != nil { + serverErrChan <- fmt.Errorf("handshake: %v", err) + return + } + + srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil)) + srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)) + srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)) + + serverSNINameChan <- sniHost + }() + + port := strings.Split(ln.Addr().String(), ":")[1] + connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param) + _, err = pgconn.Connect(context.Background(), connStr) + + select { + case sniHost := <-serverSNINameChan: + if tt.sni_set { + require.Equal(t, sniHost, "localhost") + } else { + require.Equal(t, sniHost, "") + } + case err = <-serverErrChan: + t.Fatalf("server failed with error: %+v", err) + case <-time.After(time.Millisecond * 100): + t.Fatal("exceeded connection timeout without erroring out") + } + }) + } +} diff --git a/pgproto3/README.md b/pgproto3/README.md new file mode 100644 index 00000000..79d3a68b --- /dev/null +++ b/pgproto3/README.md @@ -0,0 +1,7 @@ +# pgproto3 + +Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. + +pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more. + +See example/pgfortune for a playful example of a fake PostgreSQL server. diff --git a/pgproto3/authentication_cleartext_password.go b/pgproto3/authentication_cleartext_password.go new file mode 100644 index 00000000..d8f98b9a --- /dev/null +++ b/pgproto3/authentication_cleartext_password.go @@ -0,0 +1,52 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required. +type AuthenticationCleartextPassword struct { +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationCleartextPassword) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationCleartextPassword) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { + if len(src) != 4 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeCleartextPassword { + return errors.New("bad auth type") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationCleartextPassword", + }) +} diff --git a/pgproto3/authentication_gss.go b/pgproto3/authentication_gss.go new file mode 100644 index 00000000..0d234222 --- /dev/null +++ b/pgproto3/authentication_gss.go @@ -0,0 +1,59 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type AuthenticationGSS struct{} + +func (a *AuthenticationGSS) Backend() {} + +func (a *AuthenticationGSS) AuthenticationResponse() {} + +func (a *AuthenticationGSS) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSS { + return errors.New("bad auth type") + } + return nil +} + +func (a *AuthenticationGSS) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 4) + dst = pgio.AppendUint32(dst, AuthTypeGSS) + return dst +} + +func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSS", + }) +} + +func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + return nil +} diff --git a/pgproto3/authentication_gss_continue.go b/pgproto3/authentication_gss_continue.go new file mode 100644 index 00000000..63789dc1 --- /dev/null +++ b/pgproto3/authentication_gss_continue.go @@ -0,0 +1,68 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type AuthenticationGSSContinue struct { + Data []byte +} + +func (a *AuthenticationGSSContinue) Backend() {} + +func (a *AuthenticationGSSContinue) AuthenticationResponse() {} + +func (a *AuthenticationGSSContinue) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSSCont { + return errors.New("bad auth type") + } + + a.Data = src[4:] + return nil +} + +func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, int32(len(a.Data))+8) + dst = pgio.AppendUint32(dst, AuthTypeGSSCont) + dst = append(dst, a.Data...) + return dst +} + +func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSSContinue", + Data: a.Data, + }) +} + +func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + a.Data = msg.Data + return nil +} diff --git a/pgproto3/authentication_md5_password.go b/pgproto3/authentication_md5_password.go new file mode 100644 index 00000000..5671c84c --- /dev/null +++ b/pgproto3/authentication_md5_password.go @@ -0,0 +1,77 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required. +type AuthenticationMD5Password struct { + Salt [4]byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationMD5Password) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationMD5Password) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationMD5Password) Decode(src []byte) error { + if len(src) != 8 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeMD5Password { + return errors.New("bad auth type") + } + + copy(dst.Salt[:], src[4:8]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 12) + dst = pgio.AppendUint32(dst, AuthTypeMD5Password) + dst = append(dst, src.Salt[:]...) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Salt [4]byte + }{ + Type: "AuthenticationMD5Password", + Salt: src.Salt, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Salt [4]byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Salt = msg.Salt + return nil +} diff --git a/pgproto3/authentication_ok.go b/pgproto3/authentication_ok.go new file mode 100644 index 00000000..88d648ae --- /dev/null +++ b/pgproto3/authentication_ok.go @@ -0,0 +1,52 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationOk is a message sent from the backend indicating that authentication was successful. +type AuthenticationOk struct { +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationOk) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationOk) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationOk) Decode(src []byte) error { + if len(src) != 4 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeOk { + return errors.New("bad auth type") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationOk) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendUint32(dst, AuthTypeOk) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationOk) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationOK", + }) +} diff --git a/pgproto3/authentication_sasl.go b/pgproto3/authentication_sasl.go new file mode 100644 index 00000000..59650d4c --- /dev/null +++ b/pgproto3/authentication_sasl.go @@ -0,0 +1,76 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required. +type AuthenticationSASL struct { + AuthMechanisms []string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASL) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationSASL) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASL) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASL { + return errors.New("bad auth type") + } + + authMechanisms := src[4:] + for len(authMechanisms) > 1 { + idx := bytes.IndexByte(authMechanisms, 0) + if idx == -1 { + return &invalidMessageFormatErr{messageType: "AuthenticationSASL", details: "unterminated string"} + } + dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx])) + authMechanisms = authMechanisms[idx+1:] + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASL) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASL) + + for _, s := range src.AuthMechanisms { + dst = append(dst, []byte(s)...) + dst = append(dst, 0) + } + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASL) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanisms []string + }{ + Type: "AuthenticationSASL", + AuthMechanisms: src.AuthMechanisms, + }) +} diff --git a/pgproto3/authentication_sasl_continue.go b/pgproto3/authentication_sasl_continue.go new file mode 100644 index 00000000..2ce70a47 --- /dev/null +++ b/pgproto3/authentication_sasl_continue.go @@ -0,0 +1,81 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge. +type AuthenticationSASLContinue struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASLContinue) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLContinue) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASLContinue) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASLContinue { + return errors.New("bad auth type") + } + + dst.Data = src[4:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) + + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLContinue", + Data: string(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/pgproto3/authentication_sasl_final.go b/pgproto3/authentication_sasl_final.go new file mode 100644 index 00000000..a38a8b91 --- /dev/null +++ b/pgproto3/authentication_sasl_final.go @@ -0,0 +1,81 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed. +type AuthenticationSASLFinal struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASLFinal) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLFinal) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASLFinal) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASLFinal { + return errors.New("bad auth type") + } + + dst.Data = src[4:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) + + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Unmarshaler. +func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLFinal", + Data: string(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/pgproto3/backend.go b/pgproto3/backend.go new file mode 100644 index 00000000..09aeb7c8 --- /dev/null +++ b/pgproto3/backend.go @@ -0,0 +1,262 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" +) + +// Backend acts as a server for the PostgreSQL wire protocol version 3. +type Backend struct { + cr *chunkReader + w io.Writer + + // tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // before it is actually transmitted (i.e. before Flush). + tracer *tracer + + wbuf []byte + + // Frontend message flyweights + bind Bind + cancelRequest CancelRequest + _close Close + copyFail CopyFail + copyData CopyData + copyDone CopyDone + describe Describe + execute Execute + flush Flush + functionCall FunctionCall + gssEncRequest GSSEncRequest + parse Parse + query Query + sslRequest SSLRequest + startupMessage StartupMessage + sync Sync + terminate Terminate + + bodyLen int + msgType byte + partialMsg bool + authType uint32 +} + +const ( + minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. + maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. +) + +// NewBackend creates a new Backend. +func NewBackend(r io.Reader, w io.Writer) *Backend { + cr := newChunkReader(r, 0) + return &Backend{cr: cr, w: w} +} + +// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is +// called. +func (b *Backend) Send(msg BackendMessage) { + prevLen := len(b.wbuf) + b.wbuf = msg.Encode(b.wbuf) + if b.tracer != nil { + b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg) + } +} + +// Flush writes any pending messages to the frontend (i.e. the client). +func (b *Backend) Flush() error { + n, err := b.w.Write(b.wbuf) + + const maxLen = 1024 + if len(b.wbuf) > maxLen { + b.wbuf = make([]byte, 0, maxLen) + } else { + b.wbuf = b.wbuf[:0] + } + + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + return nil +} + +// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function +// PQtrace. +func (b *Backend) Trace(w io.Writer, options TracerOptions) { + b.tracer = &tracer{ + w: w, + buf: &bytes.Buffer{}, + TracerOptions: options, + } +} + +// Untrace stops tracing. +func (b *Backend) Untrace() { + b.tracer = nil +} + +// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method +// because the initial connection message is "special" and does not include the message type as the first byte. This +// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest. +func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { + buf, err := b.cr.Next(4) + if err != nil { + return nil, err + } + msgSize := int(binary.BigEndian.Uint32(buf) - 4) + + if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { + return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) + } + + buf, err = b.cr.Next(msgSize) + if err != nil { + return nil, translateEOFtoErrUnexpectedEOF(err) + } + + code := binary.BigEndian.Uint32(buf) + + switch code { + case ProtocolVersionNumber: + err = b.startupMessage.Decode(buf) + if err != nil { + return nil, err + } + return &b.startupMessage, nil + case sslRequestNumber: + err = b.sslRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.sslRequest, nil + case cancelRequestCode: + err = b.cancelRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.cancelRequest, nil + case gssEncReqNumber: + err = b.gssEncRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.gssEncRequest, nil + default: + return nil, fmt.Errorf("unknown startup message code: %d", code) + } +} + +// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive. +func (b *Backend) Receive() (FrontendMessage, error) { + if !b.partialMsg { + header, err := b.cr.Next(5) + if err != nil { + return nil, translateEOFtoErrUnexpectedEOF(err) + } + + b.msgType = header[0] + b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + b.partialMsg = true + } + + var msg FrontendMessage + switch b.msgType { + case 'B': + msg = &b.bind + case 'C': + msg = &b._close + case 'D': + msg = &b.describe + case 'E': + msg = &b.execute + case 'F': + msg = &b.functionCall + case 'f': + msg = &b.copyFail + case 'd': + msg = &b.copyData + case 'c': + msg = &b.copyDone + case 'H': + msg = &b.flush + case 'P': + msg = &b.parse + case 'p': + switch b.authType { + case AuthTypeSASL: + msg = &SASLInitialResponse{} + case AuthTypeSASLContinue: + msg = &SASLResponse{} + case AuthTypeSASLFinal: + msg = &SASLResponse{} + case AuthTypeGSS, AuthTypeGSSCont: + msg = &GSSResponse{} + case AuthTypeCleartextPassword, AuthTypeMD5Password: + fallthrough + default: + // to maintain backwards compatability + msg = &PasswordMessage{} + } + case 'Q': + msg = &b.query + case 'S': + msg = &b.sync + case 'X': + msg = &b.terminate + default: + return nil, fmt.Errorf("unknown message type: %c", b.msgType) + } + + msgBody, err := b.cr.Next(b.bodyLen) + if err != nil { + return nil, translateEOFtoErrUnexpectedEOF(err) + } + + b.partialMsg = false + + err = msg.Decode(msgBody) + if err != nil { + return nil, err + } + + if b.tracer != nil { + b.tracer.traceMessage('F', int32(5+len(msgBody)), msg) + } + + return msg, nil +} + +// SetAuthType sets the authentication type in the backend. +// Since multiple message types can start with 'p', SetAuthType allows +// contextual identification of FrontendMessages. For example, in the +// PG message flow documentation for PasswordMessage: +// +// Byte1('p') +// +// Identifies the message as a password response. Note that this is also used for +// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from +// the context. +// +// Since the Frontend does not know about the state of a backend, it is important +// to call SetAuthType() after an authentication request is received by the Frontend. +func (b *Backend) SetAuthType(authType uint32) error { + switch authType { + case AuthTypeOk, + AuthTypeCleartextPassword, + AuthTypeMD5Password, + AuthTypeSCMCreds, + AuthTypeGSS, + AuthTypeGSSCont, + AuthTypeSSPI, + AuthTypeSASL, + AuthTypeSASLContinue, + AuthTypeSASLFinal: + b.authType = authType + default: + return fmt.Errorf("authType not recognized: %d", authType) + } + + return nil +} diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go new file mode 100644 index 00000000..12c60817 --- /dev/null +++ b/pgproto3/backend_key_data.go @@ -0,0 +1,51 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type BackendKeyData struct { + ProcessID uint32 + SecretKey uint32 +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*BackendKeyData) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *BackendKeyData) Decode(src []byte) error { + if len(src) != 8 { + return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} + } + + dst.ProcessID = binary.BigEndian.Uint32(src[:4]) + dst.SecretKey = binary.BigEndian.Uint32(src[4:]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *BackendKeyData) Encode(dst []byte) []byte { + dst = append(dst, 'K') + dst = pgio.AppendUint32(dst, 12) + dst = pgio.AppendUint32(dst, src.ProcessID) + dst = pgio.AppendUint32(dst, src.SecretKey) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src BackendKeyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "BackendKeyData", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go new file mode 100644 index 00000000..596245dd --- /dev/null +++ b/pgproto3/backend_test.go @@ -0,0 +1,122 @@ +package pgproto3_test + +import ( + "io" + "testing" + + "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBackendReceiveInterrupted(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Q', 0, 0, 0, 6}) + + backend := pgproto3.NewBackend(server, nil) + + msg, err := backend.Receive() + if err == nil { + t.Fatal("expected err") + } + if msg != nil { + t.Fatalf("did not expect msg, but %v", msg) + } + + server.push([]byte{'I', 0}) + + msg, err = backend.Receive() + if err != nil { + t.Fatal(err) + } + if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" { + t.Fatalf("unexpected msg: %v", msg) + } +} + +func TestBackendReceiveUnexpectedEOF(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Q', 0, 0, 0, 6}) + + backend := pgproto3.NewBackend(server, nil) + + // Receive regular msg + msg, err := backend.Receive() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) + + // Receive StartupMessage msg + dst := []byte{} + dst = pgio.AppendUint32(dst, 1000) // tell the backend we expect 1000 bytes to be read + dst = pgio.AppendUint32(dst, 1) // only send 1 byte + server.push(dst) + + msg, err = backend.ReceiveStartupMessage() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) +} + +func TestStartupMessage(t *testing.T) { + t.Parallel() + + t.Run("valid StartupMessage", func(t *testing.T) { + want := &pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: map[string]string{ + "username": "tester", + }, + } + dst := []byte{} + dst = want.Encode(dst) + + server := &interruptReader{} + server.push(dst) + + backend := pgproto3.NewBackend(server, nil) + + msg, err := backend.ReceiveStartupMessage() + require.NoError(t, err) + require.Equal(t, want, msg) + }) + + t.Run("invalid packet length", func(t *testing.T) { + wantErr := "invalid length of startup packet" + tests := []struct { + name string + packetLen uint32 + }{ + { + name: "large packet length", + // Since the StartupMessage contains the "Length of message contents + // in bytes, including self", the max startup packet length is actually + // 10000+4. Therefore, let's go past the limit with 10005 + packetLen: 10005, + }, + { + name: "short packet length", + packetLen: 3, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := &interruptReader{} + dst := []byte{} + dst = pgio.AppendUint32(dst, tt.packetLen) + dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber) + server.push(dst) + + backend := pgproto3.NewBackend(server, nil) + + msg, err := backend.ReceiveStartupMessage() + require.Error(t, err) + require.Nil(t, msg) + require.Contains(t, err.Error(), wantErr) + }) + } + }) +} diff --git a/pgproto3/big_endian.go b/pgproto3/big_endian.go new file mode 100644 index 00000000..f7bdb97e --- /dev/null +++ b/pgproto3/big_endian.go @@ -0,0 +1,37 @@ +package pgproto3 + +import ( + "encoding/binary" +) + +type BigEndianBuf [8]byte + +func (b BigEndianBuf) Int16(n int16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, uint16(n)) + return buf +} + +func (b BigEndianBuf) Uint16(n uint16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, n) + return buf +} + +func (b BigEndianBuf) Int32(n int32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, uint32(n)) + return buf +} + +func (b BigEndianBuf) Uint32(n uint32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, n) + return buf +} + +func (b BigEndianBuf) Int64(n int64) []byte { + buf := b[0:8] + binary.BigEndian.PutUint64(buf, uint64(n)) + return buf +} diff --git a/pgproto3/bind.go b/pgproto3/bind.go new file mode 100644 index 00000000..fdd2d3b8 --- /dev/null +++ b/pgproto3/bind.go @@ -0,0 +1,216 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Bind struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters [][]byte + ResultFormatCodes []int16 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Bind) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Bind) Decode(src []byte) error { + *dst = Bind{} + + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.DestinationPortal = string(src[:idx]) + rp := idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.PreparedStatement = string(src[rp : rp+idx]) + rp += idx + 1 + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if parameterFormatCodeCount > 0 { + dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) + + if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < parameterFormatCodeCount; i++ { + dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if parameterCount > 0 { + dst.Parameters = make([][]byte, parameterCount) + + for i := 0; i < parameterCount; i++ { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + // null + if msgSize == -1 { + continue + } + + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + dst.Parameters[i] = src[rp : rp+msgSize] + rp += msgSize + } + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.ResultFormatCodes = make([]int16, resultFormatCodeCount) + if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < resultFormatCodeCount; i++ { + dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Bind) Encode(dst []byte) []byte { + dst = append(dst, 'B') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.DestinationPortal...) + dst = append(dst, 0) + dst = append(dst, src.PreparedStatement...) + dst = append(dst, 0) + + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) + for _, fc := range src.ParameterFormatCodes { + dst = pgio.AppendInt16(dst, fc) + } + + dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) + for _, p := range src.Parameters { + if p == nil { + dst = pgio.AppendInt32(dst, -1) + continue + } + + dst = pgio.AppendInt32(dst, int32(len(p))) + dst = append(dst, p...) + } + + dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) + for _, fc := range src.ResultFormatCodes { + dst = pgio.AppendInt16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Bind) MarshalJSON() ([]byte, error) { + formattedParameters := make([]map[string]string, len(src.Parameters)) + for i, p := range src.Parameters { + if p == nil { + continue + } + + textFormat := true + if len(src.ParameterFormatCodes) == 1 { + textFormat = src.ParameterFormatCodes[0] == 0 + } else if len(src.ParameterFormatCodes) > 1 { + textFormat = src.ParameterFormatCodes[i] == 0 + } + + if textFormat { + formattedParameters[i] = map[string]string{"text": string(p)} + } else { + formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)} + } + } + + return json.Marshal(struct { + Type string + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + }{ + Type: "Bind", + DestinationPortal: src.DestinationPortal, + PreparedStatement: src.PreparedStatement, + ParameterFormatCodes: src.ParameterFormatCodes, + Parameters: formattedParameters, + ResultFormatCodes: src.ResultFormatCodes, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Bind) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + dst.DestinationPortal = msg.DestinationPortal + dst.PreparedStatement = msg.PreparedStatement + dst.ParameterFormatCodes = msg.ParameterFormatCodes + dst.Parameters = make([][]byte, len(msg.Parameters)) + dst.ResultFormatCodes = msg.ResultFormatCodes + for n, parameter := range msg.Parameters { + dst.Parameters[n], err = getValueFromJSON(parameter) + if err != nil { + return fmt.Errorf("cannot get param %d: %w", n, err) + } + } + return nil +} diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go new file mode 100644 index 00000000..3be256c8 --- /dev/null +++ b/pgproto3/bind_complete.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type BindComplete struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*BindComplete) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *BindComplete) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *BindComplete) Encode(dst []byte) []byte { + return append(dst, '2', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src BindComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "BindComplete", + }) +} diff --git a/pgproto3/cancel_request.go b/pgproto3/cancel_request.go new file mode 100644 index 00000000..8fcf8217 --- /dev/null +++ b/pgproto3/cancel_request.go @@ -0,0 +1,58 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +const cancelRequestCode = 80877102 + +type CancelRequest struct { + ProcessID uint32 + SecretKey uint32 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CancelRequest) Frontend() {} + +func (dst *CancelRequest) Decode(src []byte) error { + if len(src) != 12 { + return errors.New("bad cancel request size") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != cancelRequestCode { + return errors.New("bad cancel request code") + } + + dst.ProcessID = binary.BigEndian.Uint32(src[4:]) + dst.SecretKey = binary.BigEndian.Uint32(src[8:]) + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *CancelRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 16) + dst = pgio.AppendInt32(dst, cancelRequestCode) + dst = pgio.AppendUint32(dst, src.ProcessID) + dst = pgio.AppendUint32(dst, src.SecretKey) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CancelRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "CancelRequest", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/pgproto3/chunkreader.go b/pgproto3/chunkreader.go new file mode 100644 index 00000000..3c35d0b1 --- /dev/null +++ b/pgproto3/chunkreader.go @@ -0,0 +1,90 @@ +package pgproto3 + +import ( + "io" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +// chunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and +// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually +// requested. The memory returned via Next is only valid until the next call to Next. +// +// This is roughly equivalent to a bufio.Reader that only uses Peek and Discard to never copy bytes. +type chunkReader struct { + r io.Reader + + buf []byte + rp, wp int // buf read position and write position + + minBufSize int +} + +// newChunkReader creates and returns a new chunkReader for r with default configuration. If minBufSize is <= 0 it uses +// a default value. +func newChunkReader(r io.Reader, minBufSize int) *chunkReader { + if minBufSize <= 0 { + // By historical reasons Postgres currently has 8KB send buffer inside, + // so here we want to have at least the same size buffer. + // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 + // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru + // + // In addition, testing has found no benefit of any larger buffer. + minBufSize = 8192 + } + + return &chunkReader{ + r: r, + minBufSize: minBufSize, + buf: iobufpool.Get(minBufSize), + } +} + +// Next returns buf filled with the next n bytes. buf is only valid until next call of Next. If an error occurs, buf +// will be nil. +func (r *chunkReader) Next(n int) (buf []byte, err error) { + // Reset the buffer if it is empty + if r.rp == r.wp { + if len(r.buf) != r.minBufSize { + iobufpool.Put(r.buf) + r.buf = iobufpool.Get(r.minBufSize) + } + r.rp = 0 + r.wp = 0 + } + + // n bytes already in buf + if (r.wp - r.rp) >= n { + buf = r.buf[r.rp : r.rp+n : r.rp+n] + r.rp += n + return buf, err + } + + // buf is smaller than requested number of bytes + if len(r.buf) < n { + bigBuf := iobufpool.Get(n) + r.wp = copy(bigBuf, r.buf[r.rp:r.wp]) + r.rp = 0 + iobufpool.Put(r.buf) + r.buf = bigBuf + } + + // buf is large enough, but need to shift filled area to start to make enough contiguous space + minReadCount := n - (r.wp - r.rp) + if (len(r.buf) - r.wp) < minReadCount { + r.wp = copy(r.buf, r.buf[r.rp:r.wp]) + r.rp = 0 + } + + // Read at least the required number of bytes from the underlying io.Reader + readBytesCount, err := io.ReadAtLeast(r.r, r.buf[r.wp:], minReadCount) + r.wp += readBytesCount + // fmt.Println("read", n) + if err != nil { + return nil, err + } + + buf = r.buf[r.rp : r.rp+n : r.rp+n] + r.rp += n + return buf, nil +} diff --git a/pgproto3/chunkreader_test.go b/pgproto3/chunkreader_test.go new file mode 100644 index 00000000..41c8ce65 --- /dev/null +++ b/pgproto3/chunkreader_test.go @@ -0,0 +1,75 @@ +package pgproto3 + +import ( + "bytes" + "math/rand" + "testing" +) + +func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { + server := &bytes.Buffer{} + r := newChunkReader(server, 4) + + src := []byte{1, 2, 3, 4} + server.Write(src) + + n1, err := r.Next(2) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:2]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1) + } + + n2, err := r.Next(2) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n2, src[2:4]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2) + } + + if bytes.Compare(r.buf[:len(src)], src) != 0 { + t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf) + } + + _, err = r.Next(0) // Trigger the buffer reset. + if err != nil { + t.Fatal(err) + } + + if r.rp != 0 { + t.Fatalf("Expected r.rp to be %v, but it was %v", 0, r.rp) + } + if r.wp != 0 { + t.Fatalf("Expected r.wp to be %v, but it was %v", 0, r.wp) + } +} + +type randomReader struct { + rnd *rand.Rand +} + +// Read reads a random number of random bytes. +func (r *randomReader) Read(p []byte) (n int, err error) { + n = r.rnd.Intn(len(p) + 1) + return r.rnd.Read(p[:n]) +} + +func TestChunkReaderNextFuzz(t *testing.T) { + rr := &randomReader{rnd: rand.New(rand.NewSource(1))} + r := newChunkReader(rr, 8192) + + randomSizes := rand.New(rand.NewSource(0)) + + for i := 0; i < 100000; i++ { + size := randomSizes.Intn(16384) + 1 + buf, err := r.Next(size) + if err != nil { + t.Fatal(err) + } + if len(buf) != size { + t.Fatalf("Expected to get %v bytes but got %v bytes", size, len(buf)) + } + } +} diff --git a/pgproto3/close.go b/pgproto3/close.go new file mode 100644 index 00000000..f99b5943 --- /dev/null +++ b/pgproto3/close.go @@ -0,0 +1,89 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Close struct { + ObjectType byte // 'S' = prepared statement, 'P' = portal + Name string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Close) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Close) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "Close"} + } + + dst.ObjectType = src[0] + rp := 1 + + idx := bytes.IndexByte(src[rp:], 0) + if idx != len(src[rp:])-1 { + return &invalidMessageFormatErr{messageType: "Close"} + } + + dst.Name = string(src[rp : len(src)-1]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Close) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.ObjectType) + dst = append(dst, src.Name...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Close) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ObjectType string + Name string + }{ + Type: "Close", + ObjectType: string(src.ObjectType), + Name: src.Name, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Close) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Close.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go new file mode 100644 index 00000000..1d7b8f08 --- /dev/null +++ b/pgproto3/close_complete.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CloseComplete struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CloseComplete) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CloseComplete) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CloseComplete) Encode(dst []byte) []byte { + return append(dst, '3', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CloseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CloseComplete", + }) +} diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go new file mode 100644 index 00000000..814027ca --- /dev/null +++ b/pgproto3/command_complete.go @@ -0,0 +1,74 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type CommandComplete struct { + CommandTag []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CommandComplete) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CommandComplete) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx == -1 { + return &invalidMessageFormatErr{messageType: "CommandComplete", details: "unterminated string"} + } + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CommandComplete", details: "string terminated too early"} + } + + dst.CommandTag = src[:idx] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CommandComplete) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.CommandTag...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CommandComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + CommandTag string + }{ + Type: "CommandComplete", + CommandTag: string(src.CommandTag), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CommandComplete) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + CommandTag string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.CommandTag = []byte(msg.CommandTag) + return nil +} diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go new file mode 100644 index 00000000..8840a89e --- /dev/null +++ b/pgproto3/copy_both_response.go @@ -0,0 +1,95 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type CopyBothResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CopyBothResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyBothResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyBothResponse) Encode(dst []byte) []byte { + dst = append(dst, 'W') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = append(dst, src.OverallFormat) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) + for _, fc := range src.ColumnFormatCodes { + dst = pgio.AppendUint16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyBothResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyBothResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyBothResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/pgproto3/copy_both_response_test.go b/pgproto3/copy_both_response_test.go new file mode 100644 index 00000000..4437de1d --- /dev/null +++ b/pgproto3/copy_both_response_test.go @@ -0,0 +1,18 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/assert" +) + +func TestEncodeDecode(t *testing.T) { + srcBytes := []byte{'W', 0x00, 0x00, 0x00, 0x0b, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01} + dstResp := pgproto3.CopyBothResponse{} + err := dstResp.Decode(srcBytes[5:]) + assert.NoError(t, err, "No errors on decode") + dstBytes := []byte{} + dstBytes = dstResp.Encode(dstBytes) + assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match") +} diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go new file mode 100644 index 00000000..59e3dd94 --- /dev/null +++ b/pgproto3/copy_data.go @@ -0,0 +1,62 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type CopyData struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CopyData) Backend() {} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CopyData) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyData) Decode(src []byte) error { + dst.Data = src + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyData) Encode(dst []byte) []byte { + dst = append(dst, 'd') + dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) + dst = append(dst, src.Data...) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "CopyData", + Data: hex.EncodeToString(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyData) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/pgproto3/copy_done.go b/pgproto3/copy_done.go new file mode 100644 index 00000000..0e13282b --- /dev/null +++ b/pgproto3/copy_done.go @@ -0,0 +1,38 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CopyDone struct { +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CopyDone) Backend() {} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CopyDone) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyDone) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyDone) Encode(dst []byte) []byte { + return append(dst, 'c', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyDone) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CopyDone", + }) +} diff --git a/pgproto3/copy_fail.go b/pgproto3/copy_fail.go new file mode 100644 index 00000000..0041bbb1 --- /dev/null +++ b/pgproto3/copy_fail.go @@ -0,0 +1,53 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type CopyFail struct { + Message string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CopyFail) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyFail) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CopyFail"} + } + + dst.Message = string(src[:idx]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyFail) Encode(dst []byte) []byte { + dst = append(dst, 'f') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Message...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyFail) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Message string + }{ + Type: "CopyFail", + Message: src.Message, + }) +} diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go new file mode 100644 index 00000000..4584f7df --- /dev/null +++ b/pgproto3/copy_in_response.go @@ -0,0 +1,96 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type CopyInResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CopyInResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyInResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyInResponse) Encode(dst []byte) []byte { + dst = append(dst, 'G') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.OverallFormat) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) + for _, fc := range src.ColumnFormatCodes { + dst = pgio.AppendUint16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyInResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyInResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyInResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyInResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go new file mode 100644 index 00000000..3175c6a4 --- /dev/null +++ b/pgproto3/copy_out_response.go @@ -0,0 +1,96 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type CopyOutResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyOutResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyOutResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyOutResponse) Encode(dst []byte) []byte { + dst = append(dst, 'H') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.OverallFormat) + + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) + for _, fc := range src.ColumnFormatCodes { + dst = pgio.AppendUint16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyOutResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyOutResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyOutResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go new file mode 100644 index 00000000..4de77977 --- /dev/null +++ b/pgproto3/data_row.go @@ -0,0 +1,142 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type DataRow struct { + Values [][]byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*DataRow) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *DataRow) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + rp := 0 + fieldCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + // If the capacity of the values slice is too small OR substantially too + // large reallocate. This is too avoid one row with many columns from + // permanently allocating memory. + if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { + newCap := 32 + if newCap < fieldCount { + newCap = fieldCount + } + dst.Values = make([][]byte, fieldCount, newCap) + } else { + dst.Values = dst.Values[:fieldCount] + } + + for i := 0; i < fieldCount; i++ { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + + valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + // null + if valueLen == -1 { + dst.Values[i] = nil + } else { + if len(src[rp:]) < valueLen || valueLen < 0 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + + dst.Values[i] = src[rp : rp+valueLen : rp+valueLen] + rp += valueLen + } + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *DataRow) Encode(dst []byte) []byte { + dst = append(dst, 'D') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.Values))) + for _, v := range src.Values { + if v == nil { + dst = pgio.AppendInt32(dst, -1) + continue + } + + dst = pgio.AppendInt32(dst, int32(len(v))) + dst = append(dst, v...) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src DataRow) MarshalJSON() ([]byte, error) { + formattedValues := make([]map[string]string, len(src.Values)) + for i, v := range src.Values { + if v == nil { + continue + } + + var hasNonPrintable bool + for _, b := range v { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)} + } else { + formattedValues[i] = map[string]string{"text": string(v)} + } + } + + return json.Marshal(struct { + Type string + Values []map[string]string + }{ + Type: "DataRow", + Values: formattedValues, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *DataRow) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Values []map[string]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Values = make([][]byte, len(msg.Values)) + for n, parameter := range msg.Values { + var err error + dst.Values[n], err = getValueFromJSON(parameter) + if err != nil { + return err + } + } + return nil +} diff --git a/pgproto3/describe.go b/pgproto3/describe.go new file mode 100644 index 00000000..f131d1f4 --- /dev/null +++ b/pgproto3/describe.go @@ -0,0 +1,88 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Describe struct { + ObjectType byte // 'S' = prepared statement, 'P' = portal + Name string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Describe) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Describe) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.ObjectType = src[0] + rp := 1 + + idx := bytes.IndexByte(src[rp:], 0) + if idx != len(src[rp:])-1 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.Name = string(src[rp : len(src)-1]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Describe) Encode(dst []byte) []byte { + dst = append(dst, 'D') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.ObjectType) + dst = append(dst, src.Name...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Describe) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ObjectType string + Name string + }{ + Type: "Describe", + ObjectType: string(src.ObjectType), + Name: src.Name, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Describe) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Describe.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/pgproto3/doc.go b/pgproto3/doc.go new file mode 100644 index 00000000..e0e1cf87 --- /dev/null +++ b/pgproto3/doc.go @@ -0,0 +1,11 @@ +// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. +// +// The primary interfaces are Frontend and Backend. They correspond to a client and server respectively. Messages are +// sent with Send (or a specialized Send variant). Messages are automatically bufferred to minimize small writes. Call +// Flush to ensure a message has actually been sent. +// +// The Trace method of Frontend and Backend can be used to examine the wire-level message traffic. It outputs in a +// similar format to the PQtrace function in libpq. +// +// See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages. +package pgproto3 diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go new file mode 100644 index 00000000..2b85e744 --- /dev/null +++ b/pgproto3/empty_query_response.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type EmptyQueryResponse struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*EmptyQueryResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *EmptyQueryResponse) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *EmptyQueryResponse) Encode(dst []byte) []byte { + return append(dst, 'I', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src EmptyQueryResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "EmptyQueryResponse", + }) +} diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go new file mode 100644 index 00000000..ec51e019 --- /dev/null +++ b/pgproto3/error_response.go @@ -0,0 +1,334 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "strconv" +) + +type ErrorResponse struct { + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*ErrorResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ErrorResponse) Decode(src []byte) error { + *dst = ErrorResponse{} + + buf := bytes.NewBuffer(src) + + for { + k, err := buf.ReadByte() + if err != nil { + return err + } + if k == 0 { + break + } + + vb, err := buf.ReadBytes(0) + if err != nil { + return err + } + v := string(vb[:len(vb)-1]) + + switch k { + case 'S': + dst.Severity = v + case 'V': + dst.SeverityUnlocalized = v + case 'C': + dst.Code = v + case 'M': + dst.Message = v + case 'D': + dst.Detail = v + case 'H': + dst.Hint = v + case 'P': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Position = int32(n) + case 'p': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.InternalPosition = int32(n) + case 'q': + dst.InternalQuery = v + case 'W': + dst.Where = v + case 's': + dst.SchemaName = v + case 't': + dst.TableName = v + case 'c': + dst.ColumnName = v + case 'd': + dst.DataTypeName = v + case 'n': + dst.ConstraintName = v + case 'F': + dst.File = v + case 'L': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Line = int32(n) + case 'R': + dst.Routine = v + + default: + if dst.UnknownFields == nil { + dst.UnknownFields = make(map[byte]string) + } + dst.UnknownFields[k] = v + } + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ErrorResponse) Encode(dst []byte) []byte { + return append(dst, src.marshalBinary('E')...) +} + +func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte(typeByte) + buf.Write(bigEndian.Uint32(0)) + + if src.Severity != "" { + buf.WriteByte('S') + buf.WriteString(src.Severity) + buf.WriteByte(0) + } + if src.SeverityUnlocalized != "" { + buf.WriteByte('V') + buf.WriteString(src.SeverityUnlocalized) + buf.WriteByte(0) + } + if src.Code != "" { + buf.WriteByte('C') + buf.WriteString(src.Code) + buf.WriteByte(0) + } + if src.Message != "" { + buf.WriteByte('M') + buf.WriteString(src.Message) + buf.WriteByte(0) + } + if src.Detail != "" { + buf.WriteByte('D') + buf.WriteString(src.Detail) + buf.WriteByte(0) + } + if src.Hint != "" { + buf.WriteByte('H') + buf.WriteString(src.Hint) + buf.WriteByte(0) + } + if src.Position != 0 { + buf.WriteByte('P') + buf.WriteString(strconv.Itoa(int(src.Position))) + buf.WriteByte(0) + } + if src.InternalPosition != 0 { + buf.WriteByte('p') + buf.WriteString(strconv.Itoa(int(src.InternalPosition))) + buf.WriteByte(0) + } + if src.InternalQuery != "" { + buf.WriteByte('q') + buf.WriteString(src.InternalQuery) + buf.WriteByte(0) + } + if src.Where != "" { + buf.WriteByte('W') + buf.WriteString(src.Where) + buf.WriteByte(0) + } + if src.SchemaName != "" { + buf.WriteByte('s') + buf.WriteString(src.SchemaName) + buf.WriteByte(0) + } + if src.TableName != "" { + buf.WriteByte('t') + buf.WriteString(src.TableName) + buf.WriteByte(0) + } + if src.ColumnName != "" { + buf.WriteByte('c') + buf.WriteString(src.ColumnName) + buf.WriteByte(0) + } + if src.DataTypeName != "" { + buf.WriteByte('d') + buf.WriteString(src.DataTypeName) + buf.WriteByte(0) + } + if src.ConstraintName != "" { + buf.WriteByte('n') + buf.WriteString(src.ConstraintName) + buf.WriteByte(0) + } + if src.File != "" { + buf.WriteByte('F') + buf.WriteString(src.File) + buf.WriteByte(0) + } + if src.Line != 0 { + buf.WriteByte('L') + buf.WriteString(strconv.Itoa(int(src.Line))) + buf.WriteByte(0) + } + if src.Routine != "" { + buf.WriteByte('R') + buf.WriteString(src.Routine) + buf.WriteByte(0) + } + + for k, v := range src.UnknownFields { + buf.WriteByte(k) + buf.WriteByte(0) + buf.WriteString(v) + buf.WriteByte(0) + } + + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes() +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src ErrorResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + }{ + Type: "ErrorResponse", + Severity: src.Severity, + SeverityUnlocalized: src.SeverityUnlocalized, + Code: src.Code, + Message: src.Message, + Detail: src.Detail, + Hint: src.Hint, + Position: src.Position, + InternalPosition: src.InternalPosition, + InternalQuery: src.InternalQuery, + Where: src.Where, + SchemaName: src.SchemaName, + TableName: src.TableName, + ColumnName: src.ColumnName, + DataTypeName: src.DataTypeName, + ConstraintName: src.ConstraintName, + File: src.File, + Line: src.Line, + Routine: src.Routine, + UnknownFields: src.UnknownFields, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ErrorResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Severity = msg.Severity + dst.SeverityUnlocalized = msg.SeverityUnlocalized + dst.Code = msg.Code + dst.Message = msg.Message + dst.Detail = msg.Detail + dst.Hint = msg.Hint + dst.Position = msg.Position + dst.InternalPosition = msg.InternalPosition + dst.InternalQuery = msg.InternalQuery + dst.Where = msg.Where + dst.SchemaName = msg.SchemaName + dst.TableName = msg.TableName + dst.ColumnName = msg.ColumnName + dst.DataTypeName = msg.DataTypeName + dst.ConstraintName = msg.ConstraintName + dst.File = msg.File + dst.Line = msg.Line + dst.Routine = msg.Routine + + dst.UnknownFields = msg.UnknownFields + + return nil +} diff --git a/pgproto3/example/pgfortune/README.md b/pgproto3/example/pgfortune/README.md new file mode 100644 index 00000000..c181c38a --- /dev/null +++ b/pgproto3/example/pgfortune/README.md @@ -0,0 +1,53 @@ +# pgfortune + +pgfortune is a mock PostgreSQL server that responds to every query with a fortune. + +## Installation + +Install `fortune` and `cowsay`. They should be available in any Unix package manager (apt, yum, brew, etc.) + +``` +go get -u github.com/jackc/pgproto3/example/pgfortune +``` + +## Usage + +``` +$ pgfortune +``` + +By default pgfortune listens on 127.0.0.1:15432 and responds to queries with `fortune | cowsay -f elephant`. These are +configurable with the `listen` and `response-command` arguments respectively. + +While `pgfortune` is running connect to it with `psql`. + +``` +$ psql -h 127.0.0.1 -p 15432 +Timing is on. +Null display is "∅". +Line style is unicode. +psql (11.5, server 0.0.0) +Type "help" for help. + +jack@127.0.0.1:15432 jack=# select foo; + fortune +───────────────────────────────────────────── + _________________________________________ ↵ + / Ships are safe in harbor, but they were \↵ + \ never meant to stay there. /↵ + ----------------------------------------- ↵ + \ /\ ___ /\ ↵ + \ // \/ \/ \\ ↵ + (( O O )) ↵ + \\ / \ // ↵ + \/ | | \/ ↵ + | | | | ↵ + | | | | ↵ + | o | ↵ + | | | | ↵ + |m| |m| ↵ + +(1 row) + +Time: 28.161 ms +``` diff --git a/pgproto3/example/pgfortune/main.go b/pgproto3/example/pgfortune/main.go new file mode 100644 index 00000000..0c25510b --- /dev/null +++ b/pgproto3/example/pgfortune/main.go @@ -0,0 +1,51 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "os/exec" +) + +var options struct { + listenAddress string + responseCommand string +} + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "usage: %s [options]\n", os.Args[0]) + flag.PrintDefaults() + } + + flag.StringVar(&options.listenAddress, "listen", "127.0.0.1:15432", "Listen address") + flag.StringVar(&options.responseCommand, "response-command", "fortune | cowsay -f elephant", "Command to execute to generate query response") + flag.Parse() + + ln, err := net.Listen("tcp", options.listenAddress) + if err != nil { + log.Fatal(err) + } + log.Println("Listening on", ln.Addr()) + + for { + conn, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + log.Println("Accepted connection from", conn.RemoteAddr()) + + b := NewPgFortuneBackend(conn, func() ([]byte, error) { + return exec.Command("sh", "-c", options.responseCommand).CombinedOutput() + }) + go func() { + err := b.Run() + if err != nil { + log.Println(err) + } + log.Println("Closed connection from", conn.RemoteAddr()) + }() + } +} diff --git a/pgproto3/example/pgfortune/server.go b/pgproto3/example/pgfortune/server.go new file mode 100644 index 00000000..14ae71f8 --- /dev/null +++ b/pgproto3/example/pgfortune/server.go @@ -0,0 +1,104 @@ +package main + +import ( + "fmt" + "net" + + "github.com/jackc/pgx/v5/pgproto3" +) + +type PgFortuneBackend struct { + backend *pgproto3.Backend + conn net.Conn + responder func() ([]byte, error) +} + +func NewPgFortuneBackend(conn net.Conn, responder func() ([]byte, error)) *PgFortuneBackend { + backend := pgproto3.NewBackend(conn, conn) + + connHandler := &PgFortuneBackend{ + backend: backend, + conn: conn, + responder: responder, + } + + return connHandler +} + +func (p *PgFortuneBackend) Run() error { + defer p.Close() + + err := p.handleStartup() + if err != nil { + return err + } + + for { + msg, err := p.backend.Receive() + if err != nil { + return fmt.Errorf("error receiving message: %w", err) + } + + switch msg.(type) { + case *pgproto3.Query: + response, err := p.responder() + if err != nil { + return fmt.Errorf("error generating query response: %w", err) + } + + buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + { + Name: []byte("fortune"), + TableOID: 0, + TableAttributeNumber: 0, + DataTypeOID: 25, + DataTypeSize: -1, + TypeModifier: -1, + Format: 0, + }, + }}).Encode(nil) + buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf) + buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf) + buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + _, err = p.conn.Write(buf) + if err != nil { + return fmt.Errorf("error writing query response: %w", err) + } + case *pgproto3.Terminate: + return nil + default: + return fmt.Errorf("received message other than Query from client: %#v", msg) + } + } +} + +func (p *PgFortuneBackend) handleStartup() error { + startupMessage, err := p.backend.ReceiveStartupMessage() + if err != nil { + return fmt.Errorf("error receiving startup message: %w", err) + } + + switch startupMessage.(type) { + case *pgproto3.StartupMessage: + buf := (&pgproto3.AuthenticationOk{}).Encode(nil) + buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + _, err = p.conn.Write(buf) + if err != nil { + return fmt.Errorf("error sending ready for query: %w", err) + } + case *pgproto3.SSLRequest: + _, err = p.conn.Write([]byte("N")) + if err != nil { + return fmt.Errorf("error sending deny SSL request: %w", err) + } + return p.handleStartup() + default: + return fmt.Errorf("unknown startup message: %#v", startupMessage) + } + + return nil +} + +func (p *PgFortuneBackend) Close() error { + return p.conn.Close() +} diff --git a/pgproto3/execute.go b/pgproto3/execute.go new file mode 100644 index 00000000..a5fee7cb --- /dev/null +++ b/pgproto3/execute.go @@ -0,0 +1,65 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Execute struct { + Portal string + MaxRows uint32 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Execute) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Execute) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Portal = string(b[:len(b)-1]) + + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Execute"} + } + dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4)) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Execute) Encode(dst []byte) []byte { + dst = append(dst, 'E') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Portal...) + dst = append(dst, 0) + + dst = pgio.AppendUint32(dst, src.MaxRows) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Execute) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Portal string + MaxRows uint32 + }{ + Type: "Execute", + Portal: src.Portal, + MaxRows: src.MaxRows, + }) +} diff --git a/pgproto3/flush.go b/pgproto3/flush.go new file mode 100644 index 00000000..2725f689 --- /dev/null +++ b/pgproto3/flush.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Flush struct{} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Flush) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Flush) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Flush) Encode(dst []byte) []byte { + return append(dst, 'H', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Flush) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Flush", + }) +} diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go new file mode 100644 index 00000000..83dea963 --- /dev/null +++ b/pgproto3/frontend.go @@ -0,0 +1,363 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" +) + +// Frontend acts as a client for the PostgreSQL wire protocol version 3. +type Frontend struct { + cr *chunkReader + w io.Writer + + // tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is + // idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq. + tracer *tracer + + wbuf []byte + + // Backend message flyweights + authenticationOk AuthenticationOk + authenticationCleartextPassword AuthenticationCleartextPassword + authenticationMD5Password AuthenticationMD5Password + authenticationGSS AuthenticationGSS + authenticationGSSContinue AuthenticationGSSContinue + authenticationSASL AuthenticationSASL + authenticationSASLContinue AuthenticationSASLContinue + authenticationSASLFinal AuthenticationSASLFinal + backendKeyData BackendKeyData + bindComplete BindComplete + closeComplete CloseComplete + commandComplete CommandComplete + copyBothResponse CopyBothResponse + copyData CopyData + copyInResponse CopyInResponse + copyOutResponse CopyOutResponse + copyDone CopyDone + dataRow DataRow + emptyQueryResponse EmptyQueryResponse + errorResponse ErrorResponse + functionCallResponse FunctionCallResponse + noData NoData + noticeResponse NoticeResponse + notificationResponse NotificationResponse + parameterDescription ParameterDescription + parameterStatus ParameterStatus + parseComplete ParseComplete + readyForQuery ReadyForQuery + rowDescription RowDescription + portalSuspended PortalSuspended + + bodyLen int + msgType byte + partialMsg bool + authType uint32 +} + +// NewFrontend creates a new Frontend. +func NewFrontend(r io.Reader, w io.Writer) *Frontend { + cr := newChunkReader(r, 0) + return &Frontend{cr: cr, w: w} +} + +// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is +// called. +// +// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods +// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an +// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden +// behind an interface. +func (f *Frontend) Send(msg FrontendMessage) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// Flush writes any pending messages to the backend (i.e. the server). +func (f *Frontend) Flush() error { + if len(f.wbuf) == 0 { + return nil + } + + n, err := f.w.Write(f.wbuf) + + const maxLen = 1024 + if len(f.wbuf) > maxLen { + f.wbuf = make([]byte, 0, maxLen) + } else { + f.wbuf = f.wbuf[:0] + } + + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + return nil +} + +// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function +// PQtrace. +func (f *Frontend) Trace(w io.Writer, options TracerOptions) { + f.tracer = &tracer{ + w: w, + buf: &bytes.Buffer{}, + TracerOptions: options, + } +} + +// Untrace stops tracing. +func (f *Frontend) Untrace() { + f.tracer = nil +} + +// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendBind(msg *Bind) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendParse(msg *Parse) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendClose(msg *Close) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendDescribe(msg *Describe) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendExecute sends a Execute message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendExecute(msg *Execute) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendSync(msg *Sync) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendQuery(msg *Query) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendUnbufferedEncodedCopyData immediately sends an encoded CopyData message to the backend (i.e. the server). This method +// is more efficient than sending a CopyData message with Send as the message data is not copied to the internal buffer +// before being written out. The internal buffer is flushed before the message is sent. +func (f *Frontend) SendUnbufferedEncodedCopyData(msg []byte) error { + err := f.Flush() + if err != nil { + return err + } + + n, err := f.w.Write(msg) + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + if f.tracer != nil { + f.tracer.traceCopyData('F', int32(len(msg)-1), &CopyData{}) + } + + return nil +} + +func translateEOFtoErrUnexpectedEOF(err error) error { + if err == io.EOF { + return io.ErrUnexpectedEOF + } + return err +} + +// Receive receives a message from the backend. The returned message is only valid until the next call to Receive. +func (f *Frontend) Receive() (BackendMessage, error) { + if !f.partialMsg { + header, err := f.cr.Next(5) + if err != nil { + return nil, translateEOFtoErrUnexpectedEOF(err) + } + + f.msgType = header[0] + + msgLength := int(binary.BigEndian.Uint32(header[1:])) + if msgLength < 4 { + return nil, fmt.Errorf("invalid message length: %d", msgLength) + } + + f.bodyLen = msgLength - 4 + f.partialMsg = true + } + + msgBody, err := f.cr.Next(f.bodyLen) + if err != nil { + return nil, translateEOFtoErrUnexpectedEOF(err) + } + + f.partialMsg = false + + var msg BackendMessage + switch f.msgType { + case '1': + msg = &f.parseComplete + case '2': + msg = &f.bindComplete + case '3': + msg = &f.closeComplete + case 'A': + msg = &f.notificationResponse + case 'c': + msg = &f.copyDone + case 'C': + msg = &f.commandComplete + case 'd': + msg = &f.copyData + case 'D': + msg = &f.dataRow + case 'E': + msg = &f.errorResponse + case 'G': + msg = &f.copyInResponse + case 'H': + msg = &f.copyOutResponse + case 'I': + msg = &f.emptyQueryResponse + case 'K': + msg = &f.backendKeyData + case 'n': + msg = &f.noData + case 'N': + msg = &f.noticeResponse + case 'R': + var err error + msg, err = f.findAuthenticationMessageType(msgBody) + if err != nil { + return nil, err + } + case 's': + msg = &f.portalSuspended + case 'S': + msg = &f.parameterStatus + case 't': + msg = &f.parameterDescription + case 'T': + msg = &f.rowDescription + case 'V': + msg = &f.functionCallResponse + case 'W': + msg = &f.copyBothResponse + case 'Z': + msg = &f.readyForQuery + default: + return nil, fmt.Errorf("unknown message type: %c", f.msgType) + } + + err = msg.Decode(msgBody) + if err != nil { + return nil, err + } + + if f.tracer != nil { + f.tracer.traceMessage('B', int32(5+len(msgBody)), msg) + } + + return msg, nil +} + +// Authentication message type constants. +// See src/include/libpq/pqcomm.h for all +// constants. +const ( + AuthTypeOk = 0 + AuthTypeCleartextPassword = 3 + AuthTypeMD5Password = 5 + AuthTypeSCMCreds = 6 + AuthTypeGSS = 7 + AuthTypeGSSCont = 8 + AuthTypeSSPI = 9 + AuthTypeSASL = 10 + AuthTypeSASLContinue = 11 + AuthTypeSASLFinal = 12 +) + +func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) { + if len(src) < 4 { + return nil, errors.New("authentication message too short") + } + f.authType = binary.BigEndian.Uint32(src[:4]) + + switch f.authType { + case AuthTypeOk: + return &f.authenticationOk, nil + case AuthTypeCleartextPassword: + return &f.authenticationCleartextPassword, nil + case AuthTypeMD5Password: + return &f.authenticationMD5Password, nil + case AuthTypeSCMCreds: + return nil, errors.New("AuthTypeSCMCreds is unimplemented") + case AuthTypeGSS: + return &f.authenticationGSS, nil + case AuthTypeGSSCont: + return &f.authenticationGSSContinue, nil + case AuthTypeSSPI: + return nil, errors.New("AuthTypeSSPI is unimplemented") + case AuthTypeSASL: + return &f.authenticationSASL, nil + case AuthTypeSASLContinue: + return &f.authenticationSASLContinue, nil + case AuthTypeSASLFinal: + return &f.authenticationSASLFinal, nil + default: + return nil, fmt.Errorf("unknown authentication type: %d", f.authType) + } +} + +// GetAuthType returns the authType used in the current state of the frontend. +// See SetAuthType for more information. +func (f *Frontend) GetAuthType() uint32 { + return f.authType +} diff --git a/pgproto3/frontend_test.go b/pgproto3/frontend_test.go new file mode 100644 index 00000000..e02457d6 --- /dev/null +++ b/pgproto3/frontend_test.go @@ -0,0 +1,117 @@ +package pgproto3_test + +import ( + "io" + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type interruptReader struct { + chunks [][]byte +} + +func (ir *interruptReader) Read(p []byte) (n int, err error) { + if len(ir.chunks) == 0 { + return 0, io.EOF + } + + n = copy(p, ir.chunks[0]) + if n != len(ir.chunks[0]) { + panic("this test reader doesn't support partial reads of chunks") + } + + ir.chunks = ir.chunks[1:] + + return n, nil +} + +func (ir *interruptReader) push(p []byte) { + ir.chunks = append(ir.chunks, p) +} + +func TestFrontendReceiveInterrupted(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Z', 0, 0, 0, 5}) + + frontend := pgproto3.NewFrontend(server, nil) + + msg, err := frontend.Receive() + if err == nil { + t.Fatal("expected err") + } + if msg != nil { + t.Fatalf("did not expect msg, but %v", msg) + } + + server.push([]byte{'I'}) + + msg, err = frontend.Receive() + if err != nil { + t.Fatal(err) + } + if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' { + t.Fatalf("unexpected msg: %v", msg) + } +} + +func TestFrontendReceiveUnexpectedEOF(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Z', 0, 0, 0, 5}) + + frontend := pgproto3.NewFrontend(server, nil) + + msg, err := frontend.Receive() + if err == nil { + t.Fatal("expected err") + } + if msg != nil { + t.Fatalf("did not expect msg, but %v", msg) + } + + msg, err = frontend.Receive() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) +} + +func TestErrorResponse(t *testing.T) { + t.Parallel() + + want := &pgproto3.ErrorResponse{ + Severity: "ERROR", + SeverityUnlocalized: "ERROR", + Message: `column "foo" does not exist`, + File: "parse_relation.c", + Code: "42703", + Position: 8, + Line: 3513, + Routine: "errorMissingColumn", + } + + raw := []byte{ + 'E', 0, 0, 0, 'f', + 'S', 'E', 'R', 'R', 'O', 'R', 0, + 'V', 'E', 'R', 'R', 'O', 'R', 0, + 'C', '4', '2', '7', '0', '3', 0, + 'M', 'c', 'o', 'l', 'u', 'm', 'n', 32, '"', 'f', 'o', 'o', '"', 32, 'd', 'o', 'e', 's', 32, 'n', 'o', 't', 32, 'e', 'x', 'i', 's', 't', 0, + 'P', '8', 0, + 'F', 'p', 'a', 'r', 's', 'e', '_', 'r', 'e', 'l', 'a', 't', 'i', 'o', 'n', '.', 'c', 0, + 'L', '3', '5', '1', '3', 0, + 'R', 'e', 'r', 'r', 'o', 'r', 'M', 'i', 's', 's', 'i', 'n', 'g', 'C', 'o', 'l', 'u', 'm', 'n', 0, 0, + } + + server := &interruptReader{} + server.push(raw) + + frontend := pgproto3.NewFrontend(server, nil) + + got, err := frontend.Receive() + require.NoError(t, err) + assert.Equal(t, want, got) +} diff --git a/pgproto3/function_call.go b/pgproto3/function_call.go new file mode 100644 index 00000000..2c4f38df --- /dev/null +++ b/pgproto3/function_call.go @@ -0,0 +1,95 @@ +package pgproto3 + +import ( + "encoding/binary" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type FunctionCall struct { + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte + ResultFormatCode uint16 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*FunctionCall) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *FunctionCall) Decode(src []byte) error { + *dst = FunctionCall{} + rp := 0 + // Specifies the object ID of the function to call. + dst.Function = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + // The number of argument format codes that follow (denoted C below). + // This can be zero to indicate that there are no arguments or that the arguments all use the default format (text); + // or one, in which case the specified format code is applied to all arguments; + // or it can equal the actual number of arguments. + nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + argumentCodes := make([]uint16, nArgumentCodes) + for i := 0; i < nArgumentCodes; i++ { + // The argument format codes. Each must presently be zero (text) or one (binary). + ac := binary.BigEndian.Uint16(src[rp:]) + if ac != 0 && ac != 1 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + argumentCodes[i] = ac + rp += 2 + } + dst.ArgFormatCodes = argumentCodes + + // Specifies the number of arguments being supplied to the function. + nArguments := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + arguments := make([][]byte, nArguments) + for i := 0; i < nArguments; i++ { + // The length of the argument value, in bytes (this count does not include itself). Can be zero. + // As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case. + argumentLength := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + if argumentLength == -1 { + arguments[i] = nil + } else { + // The value of the argument, in the format indicated by the associated format code. n is the above length. + argumentValue := src[rp : rp+argumentLength] + rp += argumentLength + arguments[i] = argumentValue + } + } + dst.Arguments = arguments + // The format code for the function result. Must presently be zero (text) or one (binary). + resultFormatCode := binary.BigEndian.Uint16(src[rp:]) + if resultFormatCode != 0 && resultFormatCode != 1 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + dst.ResultFormatCode = resultFormatCode + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *FunctionCall) Encode(dst []byte) []byte { + dst = append(dst, 'F') + sp := len(dst) + dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end + dst = pgio.AppendUint32(dst, src.Function) + dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) + for _, argFormatCode := range src.ArgFormatCodes { + dst = pgio.AppendUint16(dst, argFormatCode) + } + dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) + for _, argument := range src.Arguments { + if argument == nil { + dst = pgio.AppendInt32(dst, -1) + } else { + dst = pgio.AppendInt32(dst, int32(len(argument))) + dst = append(dst, argument...) + } + } + dst = pgio.AppendUint16(dst, src.ResultFormatCode) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + return dst +} diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go new file mode 100644 index 00000000..3d3606dd --- /dev/null +++ b/pgproto3/function_call_response.go @@ -0,0 +1,101 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type FunctionCallResponse struct { + Result []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*FunctionCallResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *FunctionCallResponse) Decode(src []byte) error { + if len(src) < 4 { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + rp := 0 + resultSize := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if resultSize == -1 { + dst.Result = nil + return nil + } + + if len(src[rp:]) != resultSize { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + + dst.Result = src[rp:] + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *FunctionCallResponse) Encode(dst []byte) []byte { + dst = append(dst, 'V') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + if src.Result == nil { + dst = pgio.AppendInt32(dst, -1) + } else { + dst = pgio.AppendInt32(dst, int32(len(src.Result))) + dst = append(dst, src.Result...) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src FunctionCallResponse) MarshalJSON() ([]byte, error) { + var formattedValue map[string]string + var hasNonPrintable bool + for _, b := range src.Result { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)} + } else { + formattedValue = map[string]string{"text": string(src.Result)} + } + + return json.Marshal(struct { + Type string + Result map[string]string + }{ + Type: "FunctionCallResponse", + Result: formattedValue, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Result map[string]string + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + dst.Result, err = getValueFromJSON(msg.Result) + return err +} diff --git a/pgproto3/function_call_test.go b/pgproto3/function_call_test.go new file mode 100644 index 00000000..8c08bb24 --- /dev/null +++ b/pgproto3/function_call_test.go @@ -0,0 +1,62 @@ +package pgproto3 + +import ( + "encoding/binary" + "reflect" + "testing" +) + +func TestFunctionCall_EncodeDecode(t *testing.T) { + type fields struct { + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte + ResultFormatCode uint16 + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"valid", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(1)}, false}, + {"invalid format code", fields{uint32(123), []uint16{2, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, true}, + {"invalid result format code", fields{uint32(123), []uint16{1, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(2)}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + src := &FunctionCall{ + Function: tt.fields.Function, + ArgFormatCodes: tt.fields.ArgFormatCodes, + Arguments: tt.fields.Arguments, + ResultFormatCode: tt.fields.ResultFormatCode, + } + encoded := src.Encode([]byte{}) + dst := &FunctionCall{} + // Check the header + msgTypeCode := encoded[0] + if msgTypeCode != 'F' { + t.Errorf("msgTypeCode %v should be 'F'", msgTypeCode) + return + } + // Check length, does not include type code character + l := binary.BigEndian.Uint32(encoded[1:5]) + if int(l) != (len(encoded) - 1) { + t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded)) + } + // Check decoding works as expected + err := dst.Decode(encoded[5:]) + if err != nil { + if !tt.wantErr { + t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) + } + return + } + + if !reflect.DeepEqual(src, dst) { + t.Error("difference after encode / decode cycle") + t.Errorf("src = %v", src) + t.Errorf("dst = %v", dst) + } + }) + } +} diff --git a/pgproto3/fuzz_test.go b/pgproto3/fuzz_test.go new file mode 100644 index 00000000..332596ab --- /dev/null +++ b/pgproto3/fuzz_test.go @@ -0,0 +1,57 @@ +package pgproto3_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func FuzzFrontend(f *testing.F) { + testcases := []struct { + msgType byte + msgLen uint32 + msgBody []byte + }{ + { + msgType: 'Z', + msgLen: 2, + msgBody: []byte{'I'}, + }, + { + msgType: 'Z', + msgLen: 5, + msgBody: []byte{'I'}, + }, + } + for _, tc := range testcases { + f.Add(tc.msgType, tc.msgLen, tc.msgBody) + } + f.Fuzz(func(t *testing.T, msgType byte, msgLen uint32, msgBody []byte) { + // Prune any msgLen > len(msgBody) because they would hang the test waiting for more input. + if int(msgLen) > len(msgBody)+4 { + return + } + + // Prune any messages that are too long. + if msgLen > 128 || len(msgBody) > 128 { + return + } + + r := &bytes.Buffer{} + w := &bytes.Buffer{} + fe := pgproto3.NewFrontend(r, w) + + var encodedMsg []byte + encodedMsg = append(encodedMsg, msgType) + encodedMsg = pgio.AppendUint32(encodedMsg, msgLen) + encodedMsg = append(encodedMsg, msgBody...) + _, err := r.Write(encodedMsg) + require.NoError(t, err) + + // Not checking anything other than no panic. + fe.Receive() + }) +} diff --git a/pgproto3/gss_enc_request.go b/pgproto3/gss_enc_request.go new file mode 100644 index 00000000..30ffc08d --- /dev/null +++ b/pgproto3/gss_enc_request.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +const gssEncReqNumber = 80877104 + +type GSSEncRequest struct { +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*GSSEncRequest) Frontend() {} + +func (dst *GSSEncRequest) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("gss encoding request too short") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != gssEncReqNumber { + return errors.New("bad gss encoding request code") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *GSSEncRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendInt32(dst, gssEncReqNumber) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src GSSEncRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "GSSEncRequest", + }) +} diff --git a/pgproto3/gss_response.go b/pgproto3/gss_response.go new file mode 100644 index 00000000..64bfbd04 --- /dev/null +++ b/pgproto3/gss_response.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type GSSResponse struct { + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (g *GSSResponse) Frontend() {} + +func (g *GSSResponse) Decode(data []byte) error { + g.Data = data + return nil +} + +func (g *GSSResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(g.Data))) + dst = append(dst, g.Data...) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (g *GSSResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "GSSResponse", + Data: g.Data, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (g *GSSResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + g.Data = msg.Data + return nil +} diff --git a/pgproto3/json_test.go b/pgproto3/json_test.go new file mode 100644 index 00000000..8fad4f88 --- /dev/null +++ b/pgproto3/json_test.go @@ -0,0 +1,611 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + "reflect" + "testing" +) + +func TestJSONUnmarshalAuthenticationMD5Password(t *testing.T) { + data := []byte(`{"Type":"AuthenticationMD5Password", "Salt":[97,98,99,100]}`) + want := AuthenticationMD5Password{ + Salt: [4]byte{'a', 'b', 'c', 'd'}, + } + + var got AuthenticationMD5Password + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationMD5Password struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASL(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASL","AuthMechanisms":["SCRAM-SHA-256"]}`) + want := AuthenticationSASL{ + []string{"SCRAM-SHA-256"}, + } + + var got AuthenticationSASL + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASL struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationGSS(t *testing.T) { + data := []byte(`{"Type":"AuthenticationGSS"}`) + want := AuthenticationGSS{} + + var got AuthenticationGSS + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationGSS struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationGSSContinue(t *testing.T) { + data := []byte(`{"Type":"AuthenticationGSSContinue","Data":[1,2,3,4]}`) + want := AuthenticationGSSContinue{Data: []byte{1, 2, 3, 4}} + + var got AuthenticationGSSContinue + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationGSSContinue struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`) + want := AuthenticationSASLContinue{ + Data: []byte{'1'}, + } + + var got AuthenticationSASLContinue + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASLContinue struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASLFinal(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASLFinal", "Data":"1"}`) + want := AuthenticationSASLFinal{ + Data: []byte{'1'}, + } + + var got AuthenticationSASLFinal + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASLFinal struct doesn't match expected value") + } +} + +func TestJSONUnmarshalBackendKeyData(t *testing.T) { + data := []byte(`{"Type":"BackendKeyData","ProcessID":8864,"SecretKey":3641487067}`) + want := BackendKeyData{ + ProcessID: 8864, + SecretKey: 3641487067, + } + + var got BackendKeyData + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled BackendKeyData struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCommandComplete(t *testing.T) { + data := []byte(`{"Type":"CommandComplete","CommandTag":"SELECT 1"}`) + want := CommandComplete{ + CommandTag: []byte("SELECT 1"), + } + + var got CommandComplete + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CommandComplete struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyBothResponse(t *testing.T) { + data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`) + want := CopyBothResponse{ + OverallFormat: 'W', + } + + var got CopyBothResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyBothResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyData(t *testing.T) { + data := []byte(`{"Type":"CopyData"}`) + want := CopyData{ + Data: []byte{}, + } + + var got CopyData + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyData struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyInResponse(t *testing.T) { + data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`) + want := CopyBothResponse{ + OverallFormat: 'W', + } + + var got CopyBothResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyBothResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyOutResponse(t *testing.T) { + data := []byte(`{"Type":"CopyOutResponse", "OverallFormat": "W"}`) + want := CopyOutResponse{ + OverallFormat: 'W', + } + + var got CopyOutResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyOutResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalDataRow(t *testing.T) { + data := []byte(`{"Type":"DataRow","Values":[{"text":"abc"},{"text":"this is a test"},{"binary":"000263d3114d2e34"}]}`) + want := DataRow{ + Values: [][]byte{ + []byte("abc"), + []byte("this is a test"), + {0, 2, 99, 211, 17, 77, 46, 52}, + }, + } + + var got DataRow + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled DataRow struct doesn't match expected value") + } +} + +func TestJSONUnmarshalErrorResponse(t *testing.T) { + data := []byte(`{"Type":"ErrorResponse", "UnknownFields": {"97": "foo"}}`) + want := ErrorResponse{ + UnknownFields: map[byte]string{ + 'a': "foo", + }, + } + + var got ErrorResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ErrorResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalFunctionCallResponse(t *testing.T) { + data := []byte(`{"Type":"FunctionCallResponse"}`) + want := FunctionCallResponse{} + + var got FunctionCallResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled FunctionCallResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalNoticeResponse(t *testing.T) { + data := []byte(`{"Type":"NoticeResponse", "UnknownFields": {"97": "foo"}}`) + want := NoticeResponse{ + UnknownFields: map[byte]string{ + 'a': "foo", + }, + } + + var got NoticeResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled NoticeResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalNotificationResponse(t *testing.T) { + data := []byte(`{"Type":"NotificationResponse"}`) + want := NotificationResponse{} + + var got NotificationResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled NotificationResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParameterDescription(t *testing.T) { + data := []byte(`{"Type":"ParameterDescription", "ParameterOIDs": [25]}`) + want := ParameterDescription{ + ParameterOIDs: []uint32{25}, + } + + var got ParameterDescription + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParameterStatus(t *testing.T) { + data := []byte(`{"Type":"ParameterStatus","Name":"TimeZone","Value":"Europe/Amsterdam"}`) + want := ParameterStatus{ + Name: "TimeZone", + Value: "Europe/Amsterdam", + } + + var got ParameterStatus + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalReadyForQuery(t *testing.T) { + data := []byte(`{"Type":"ReadyForQuery","TxStatus":"I"}`) + want := ReadyForQuery{ + TxStatus: 'I', + } + + var got ReadyForQuery + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalRowDescription(t *testing.T) { + data := []byte(`{"Type":"RowDescription","Fields":[{"Name":"generate_series","TableOID":0,"TableAttributeNumber":0,"DataTypeOID":23,"DataTypeSize":4,"TypeModifier":-1,"Format":0}]}`) + want := RowDescription{ + Fields: []FieldDescription{ + { + Name: []byte("generate_series"), + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: -1, + }, + }, + } + + var got RowDescription + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled RowDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalBind(t *testing.T) { + var testCases = []struct { + desc string + data []byte + }{ + { + "textual", + []byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"text":"ABC-123"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`), + }, + { + "binary", + []byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"binary":"` + hex.EncodeToString([]byte("ABC-123")) + `"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + var want = Bind{ + PreparedStatement: "lrupsc_1_0", + ParameterFormatCodes: []int16{0}, + Parameters: [][]byte{[]byte("ABC-123")}, + ResultFormatCodes: []int16{0, 0, 0, 0, 0, 1, 1}, + } + + var got Bind + if err := json.Unmarshal(tc.data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Bind struct doesn't match expected value") + } + }) + } +} + +func TestJSONUnmarshalCancelRequest(t *testing.T) { + data := []byte(`{"Type":"CancelRequest","ProcessID":8864,"SecretKey":3641487067}`) + want := CancelRequest{ + ProcessID: 8864, + SecretKey: 3641487067, + } + + var got CancelRequest + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CancelRequest struct doesn't match expected value") + } +} + +func TestJSONUnmarshalClose(t *testing.T) { + data := []byte(`{"Type":"Close","ObjectType":"S","Name":"abc"}`) + want := Close{ + ObjectType: 'S', + Name: "abc", + } + + var got Close + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Close struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyFail(t *testing.T) { + data := []byte(`{"Type":"CopyFail","Message":"abc"}`) + want := CopyFail{ + Message: "abc", + } + + var got CopyFail + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyFail struct doesn't match expected value") + } +} + +func TestJSONUnmarshalDescribe(t *testing.T) { + data := []byte(`{"Type":"Describe","ObjectType":"S","Name":"abc"}`) + want := Describe{ + ObjectType: 'S', + Name: "abc", + } + + var got Describe + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Describe struct doesn't match expected value") + } +} + +func TestJSONUnmarshalExecute(t *testing.T) { + data := []byte(`{"Type":"Execute","Portal":"","MaxRows":0}`) + want := Execute{} + + var got Execute + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Execute struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParse(t *testing.T) { + data := []byte(`{"Type":"Parse","Name":"lrupsc_1_0","Query":"SELECT id, name FROM t WHERE id = $1","ParameterOIDs":null}`) + want := Parse{ + Name: "lrupsc_1_0", + Query: "SELECT id, name FROM t WHERE id = $1", + } + + var got Parse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Parse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalPasswordMessage(t *testing.T) { + data := []byte(`{"Type":"PasswordMessage","Password":"abcdef"}`) + want := PasswordMessage{ + Password: "abcdef", + } + + var got PasswordMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled PasswordMessage struct doesn't match expected value") + } +} + +func TestJSONUnmarshalQuery(t *testing.T) { + data := []byte(`{"Type":"Query","String":"SELECT 1"}`) + want := Query{ + String: "SELECT 1", + } + + var got Query + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Query struct doesn't match expected value") + } +} + +func TestJSONUnmarshalSASLInitialResponse(t *testing.T) { + data := []byte(`{"Type":"SASLInitialResponse", "AuthMechanism":"SCRAM-SHA-256", "Data": "6D"}`) + want := SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: []byte{109}, + } + + var got SASLInitialResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled SASLInitialResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalSASLResponse(t *testing.T) { + data := []byte(`{"Type":"SASLResponse","Message":"abc"}`) + want := SASLResponse{} + + var got SASLResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled SASLResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalStartupMessage(t *testing.T) { + data := []byte(`{"Type":"StartupMessage","ProtocolVersion":196608,"Parameters":{"database":"testing","user":"postgres"}}`) + want := StartupMessage{ + ProtocolVersion: 196608, + Parameters: map[string]string{ + "database": "testing", + "user": "postgres", + }, + } + + var got StartupMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled StartupMessage struct doesn't match expected value") + } +} + +func TestAuthenticationOK(t *testing.T) { + data := []byte(`{"Type":"AuthenticationOK"}`) + want := AuthenticationOk{} + + var got AuthenticationOk + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationOK struct doesn't match expected value") + } +} + +func TestAuthenticationCleartextPassword(t *testing.T) { + data := []byte(`{"Type":"AuthenticationCleartextPassword"}`) + want := AuthenticationCleartextPassword{} + + var got AuthenticationCleartextPassword + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationCleartextPassword struct doesn't match expected value") + } +} + +func TestAuthenticationMD5Password(t *testing.T) { + data := []byte(`{"Type":"AuthenticationMD5Password","Salt":[1,2,3,4]}`) + want := AuthenticationMD5Password{ + Salt: [4]byte{1, 2, 3, 4}, + } + + var got AuthenticationMD5Password + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationMD5Password struct doesn't match expected value") + } +} + +func TestJSONUnmarshalGSSResponse(t *testing.T) { + data := []byte(`{"Type":"GSSResponse","Data":[10,20,30,40]}`) + want := GSSResponse{Data: []byte{10, 20, 30, 40}} + + var got GSSResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled GSSResponse struct doesn't match expected value") + } +} + +func TestErrorResponse(t *testing.T) { + data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`) + want := ErrorResponse{ + UnknownFields: map[byte]string{ + 'p': "foo", + }, + Code: "Fail", + Position: 1, + Message: "this is an error", + } + + var got ErrorResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ErrorResponse struct doesn't match expected value") + } +} diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go new file mode 100644 index 00000000..d8f85d38 --- /dev/null +++ b/pgproto3/no_data.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type NoData struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*NoData) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *NoData) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NoData) Encode(dst []byte) []byte { + return append(dst, 'n', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src NoData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "NoData", + }) +} diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go new file mode 100644 index 00000000..4ac28a79 --- /dev/null +++ b/pgproto3/notice_response.go @@ -0,0 +1,17 @@ +package pgproto3 + +type NoticeResponse ErrorResponse + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*NoticeResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *NoticeResponse) Decode(src []byte) error { + return (*ErrorResponse)(dst).Decode(src) +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NoticeResponse) Encode(dst []byte) []byte { + return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) +} diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go new file mode 100644 index 00000000..228e0dac --- /dev/null +++ b/pgproto3/notification_response.go @@ -0,0 +1,77 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type NotificationResponse struct { + PID uint32 + Channel string + Payload string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*NotificationResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *NotificationResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "NotificationResponse", details: "too short"} + } + + pid := binary.BigEndian.Uint32(buf.Next(4)) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + channel := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + payload := string(b[:len(b)-1]) + + *dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload} + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NotificationResponse) Encode(dst []byte) []byte { + dst = append(dst, 'A') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint32(dst, src.PID) + dst = append(dst, src.Channel...) + dst = append(dst, 0) + dst = append(dst, src.Payload...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src NotificationResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + PID uint32 + Channel string + Payload string + }{ + Type: "NotificationResponse", + PID: src.PID, + Channel: src.Channel, + Payload: src.Payload, + }) +} diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go new file mode 100644 index 00000000..374d38a3 --- /dev/null +++ b/pgproto3/parameter_description.go @@ -0,0 +1,66 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type ParameterDescription struct { + ParameterOIDs []uint32 +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*ParameterDescription) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ParameterDescription) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "ParameterDescription"} + } + + // Reported parameter count will be incorrect when number of args is greater than uint16 + buf.Next(2) + // Instead infer parameter count by remaining size of message + parameterCount := buf.Len() / 4 + + *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} + + for i := 0; i < parameterCount; i++ { + dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ParameterDescription) Encode(dst []byte) []byte { + dst = append(dst, 't') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) + for _, oid := range src.ParameterOIDs { + dst = pgio.AppendUint32(dst, oid) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src ParameterDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ParameterOIDs []uint32 + }{ + Type: "ParameterDescription", + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go new file mode 100644 index 00000000..a303e453 --- /dev/null +++ b/pgproto3/parameter_status.go @@ -0,0 +1,66 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type ParameterStatus struct { + Name string + Value string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*ParameterStatus) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ParameterStatus) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + name := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + value := string(b[:len(b)-1]) + + *dst = ParameterStatus{Name: name, Value: value} + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ParameterStatus) Encode(dst []byte) []byte { + dst = append(dst, 'S') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Name...) + dst = append(dst, 0) + dst = append(dst, src.Value...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (ps ParameterStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Value string + }{ + Type: "ParameterStatus", + Name: ps.Name, + Value: ps.Value, + }) +} diff --git a/pgproto3/parse.go b/pgproto3/parse.go new file mode 100644 index 00000000..b53200dc --- /dev/null +++ b/pgproto3/parse.go @@ -0,0 +1,88 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Parse struct { + Name string + Query string + ParameterOIDs []uint32 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Parse) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Parse) Decode(src []byte) error { + *dst = Parse{} + + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Name = string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + dst.Query = string(b[:len(b)-1]) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "Parse"} + } + parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + for i := 0; i < parameterOIDCount; i++ { + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Parse"} + } + dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4))) + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Parse) Encode(dst []byte) []byte { + dst = append(dst, 'P') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Name...) + dst = append(dst, 0) + dst = append(dst, src.Query...) + dst = append(dst, 0) + + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) + for _, oid := range src.ParameterOIDs { + dst = pgio.AppendUint32(dst, oid) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Parse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Query string + ParameterOIDs []uint32 + }{ + Type: "Parse", + Name: src.Name, + Query: src.Query, + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go new file mode 100644 index 00000000..92c9498b --- /dev/null +++ b/pgproto3/parse_complete.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ParseComplete struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*ParseComplete) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ParseComplete) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ParseComplete) Encode(dst []byte) []byte { + return append(dst, '1', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src ParseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "ParseComplete", + }) +} diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go new file mode 100644 index 00000000..41f98692 --- /dev/null +++ b/pgproto3/password_message.go @@ -0,0 +1,54 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type PasswordMessage struct { + Password string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*PasswordMessage) Frontend() {} + +// Frontend identifies this message as an authentication response. +func (*PasswordMessage) InitialResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *PasswordMessage) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Password = string(b[:len(b)-1]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *PasswordMessage) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) + + dst = append(dst, src.Password...) + dst = append(dst, 0) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src PasswordMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Password string + }{ + Type: "PasswordMessage", + Password: src.Password, + }) +} diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go new file mode 100644 index 00000000..ef5a5489 --- /dev/null +++ b/pgproto3/pgproto3.go @@ -0,0 +1,85 @@ +package pgproto3 + +import ( + "encoding/hex" + "errors" + "fmt" +) + +// Message is the interface implemented by an object that can decode and encode +// a particular PostgreSQL message. +type Message interface { + // Decode is allowed and expected to retain a reference to data after + // returning (unlike encoding.BinaryUnmarshaler). + Decode(data []byte) error + + // Encode appends itself to dst and returns the new buffer. + Encode(dst []byte) []byte +} + +// FrontendMessage is a message sent by the frontend (i.e. the client). +type FrontendMessage interface { + Message + Frontend() // no-op method to distinguish frontend from backend methods +} + +// BackendMessage is a message sent by the backend (i.e. the server). +type BackendMessage interface { + Message + Backend() // no-op method to distinguish frontend from backend methods +} + +type AuthenticationResponseMessage interface { + BackendMessage + AuthenticationResponse() // no-op method to distinguish authentication responses +} + +type invalidMessageLenErr struct { + messageType string + expectedLen int + actualLen int +} + +func (e *invalidMessageLenErr) Error() string { + return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen) +} + +type invalidMessageFormatErr struct { + messageType string + details string +} + +func (e *invalidMessageFormatErr) Error() string { + return fmt.Sprintf("%s body is invalid %s", e.messageType, e.details) +} + +type writeError struct { + err error + safeToRetry bool +} + +func (e *writeError) Error() string { + return fmt.Sprintf("write failed: %s", e.err.Error()) +} + +func (e *writeError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *writeError) Unwrap() error { + return e.err +} + +// getValueFromJSON gets the value from a protocol message representation in JSON. +func getValueFromJSON(v map[string]string) ([]byte, error) { + if v == nil { + return nil, nil + } + if text, ok := v["text"]; ok { + return []byte(text), nil + } + if binary, ok := v["binary"]; ok { + return hex.DecodeString(binary) + } + return nil, errors.New("unknown protocol representation") +} diff --git a/pgproto3/portal_suspended.go b/pgproto3/portal_suspended.go new file mode 100644 index 00000000..1a9e7bfb --- /dev/null +++ b/pgproto3/portal_suspended.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type PortalSuspended struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*PortalSuspended) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *PortalSuspended) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "PortalSuspended", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *PortalSuspended) Encode(dst []byte) []byte { + return append(dst, 's', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src PortalSuspended) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "PortalSuspended", + }) +} diff --git a/pgproto3/query.go b/pgproto3/query.go new file mode 100644 index 00000000..e963a0ec --- /dev/null +++ b/pgproto3/query.go @@ -0,0 +1,50 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Query struct { + String string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Query) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Query) Decode(src []byte) error { + i := bytes.IndexByte(src, 0) + if i != len(src)-1 { + return &invalidMessageFormatErr{messageType: "Query"} + } + + dst.String = string(src[:i]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Query) Encode(dst []byte) []byte { + dst = append(dst, 'Q') + dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) + + dst = append(dst, src.String...) + dst = append(dst, 0) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Query) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + String string + }{ + Type: "Query", + String: src.String, + }) +} diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go new file mode 100644 index 00000000..67a39be3 --- /dev/null +++ b/pgproto3/ready_for_query.go @@ -0,0 +1,61 @@ +package pgproto3 + +import ( + "encoding/json" + "errors" +) + +type ReadyForQuery struct { + TxStatus byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*ReadyForQuery) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ReadyForQuery) Decode(src []byte) error { + if len(src) != 1 { + return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} + } + + dst.TxStatus = src[0] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ReadyForQuery) Encode(dst []byte) []byte { + return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src ReadyForQuery) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + TxStatus string + }{ + Type: "ReadyForQuery", + TxStatus: string(src.TxStatus), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + TxStatus string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.TxStatus) != 1 { + return errors.New("invalid length for ReadyForQuery.TxStatus") + } + dst.TxStatus = msg.TxStatus[0] + return nil +} diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go new file mode 100644 index 00000000..6f6f0681 --- /dev/null +++ b/pgproto3/row_description.go @@ -0,0 +1,165 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +const ( + TextFormat = 0 + BinaryFormat = 1 +) + +type FieldDescription struct { + Name []byte + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 +} + +// MarshalJSON implements encoding/json.Marshaler. +func (fd FieldDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 + }{ + Name: string(fd.Name), + TableOID: fd.TableOID, + TableAttributeNumber: fd.TableAttributeNumber, + DataTypeOID: fd.DataTypeOID, + DataTypeSize: fd.DataTypeSize, + TypeModifier: fd.TypeModifier, + Format: fd.Format, + }) +} + +type RowDescription struct { + Fields []FieldDescription +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*RowDescription) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *RowDescription) Decode(src []byte) error { + + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + fieldCount := int(binary.BigEndian.Uint16(src)) + rp := 2 + + dst.Fields = dst.Fields[0:0] + + for i := 0; i < fieldCount; i++ { + var fd FieldDescription + + idx := bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + fd.Name = src[rp : rp+idx] + rp += idx + 1 + + // Since buf.Next() doesn't return an error if we hit the end of the buffer + // check Len ahead of time + if len(src[rp:]) < 18 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + + fd.TableOID = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + fd.TableAttributeNumber = binary.BigEndian.Uint16(src[rp:]) + rp += 2 + fd.DataTypeOID = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + fd.DataTypeSize = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + fd.TypeModifier = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + fd.Format = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.Fields = append(dst.Fields, fd) + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *RowDescription) Encode(dst []byte) []byte { + dst = append(dst, 'T') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) + for _, fd := range src.Fields { + dst = append(dst, fd.Name...) + dst = append(dst, 0) + + dst = pgio.AppendUint32(dst, fd.TableOID) + dst = pgio.AppendUint16(dst, fd.TableAttributeNumber) + dst = pgio.AppendUint32(dst, fd.DataTypeOID) + dst = pgio.AppendInt16(dst, fd.DataTypeSize) + dst = pgio.AppendInt32(dst, fd.TypeModifier) + dst = pgio.AppendInt16(dst, fd.Format) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src RowDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Fields []FieldDescription + }{ + Type: "RowDescription", + Fields: src.Fields, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *RowDescription) UnmarshalJSON(data []byte) error { + var msg struct { + Fields []struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 + } + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + dst.Fields = make([]FieldDescription, len(msg.Fields)) + for n, field := range msg.Fields { + dst.Fields[n] = FieldDescription{ + Name: []byte(field.Name), + TableOID: field.TableOID, + TableAttributeNumber: field.TableAttributeNumber, + DataTypeOID: field.DataTypeOID, + DataTypeSize: field.DataTypeSize, + TypeModifier: field.TypeModifier, + Format: field.Format, + } + } + return nil +} diff --git a/pgproto3/sasl_initial_response.go b/pgproto3/sasl_initial_response.go new file mode 100644 index 00000000..eeda4691 --- /dev/null +++ b/pgproto3/sasl_initial_response.go @@ -0,0 +1,94 @@ +package pgproto3 + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type SASLInitialResponse struct { + AuthMechanism string + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*SASLInitialResponse) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *SASLInitialResponse) Decode(src []byte) error { + *dst = SASLInitialResponse{} + + rp := 0 + + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return errors.New("invalid SASLInitialResponse") + } + + dst.AuthMechanism = string(src[rp:idx]) + rp = idx + 1 + + rp += 4 // The rest of the message is data so we can just skip the size + dst.Data = src[rp:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *SASLInitialResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, []byte(src.AuthMechanism)...) + dst = append(dst, 0) + + dst = pgio.AppendInt32(dst, int32(len(src.Data))) + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src SASLInitialResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanism string + Data string + }{ + Type: "SASLInitialResponse", + AuthMechanism: src.AuthMechanism, + Data: string(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + AuthMechanism string + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + dst.AuthMechanism = msg.AuthMechanism + if msg.Data != "" { + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded + } + return nil +} diff --git a/pgproto3/sasl_response.go b/pgproto3/sasl_response.go new file mode 100644 index 00000000..54c3d96f --- /dev/null +++ b/pgproto3/sasl_response.go @@ -0,0 +1,61 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type SASLResponse struct { + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*SASLResponse) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *SASLResponse) Decode(src []byte) error { + *dst = SASLResponse{Data: src} + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *SASLResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) + + dst = append(dst, src.Data...) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src SASLResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "SASLResponse", + Data: string(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if msg.Data != "" { + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded + } + return nil +} diff --git a/pgproto3/ssl_request.go b/pgproto3/ssl_request.go new file mode 100644 index 00000000..1b00c16b --- /dev/null +++ b/pgproto3/ssl_request.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +const sslRequestNumber = 80877103 + +type SSLRequest struct { +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*SSLRequest) Frontend() {} + +func (dst *SSLRequest) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("ssl request too short") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != sslRequestNumber { + return errors.New("bad ssl request code") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *SSLRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendInt32(dst, sslRequestNumber) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src SSLRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "SSLRequest", + }) +} diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go new file mode 100644 index 00000000..5c974f02 --- /dev/null +++ b/pgproto3/startup_message.go @@ -0,0 +1,96 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +const ProtocolVersionNumber = 196608 // 3.0 + +type StartupMessage struct { + ProtocolVersion uint32 + Parameters map[string]string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*StartupMessage) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *StartupMessage) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("startup message too short") + } + + dst.ProtocolVersion = binary.BigEndian.Uint32(src) + rp := 4 + + if dst.ProtocolVersion != ProtocolVersionNumber { + return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) + } + + dst.Parameters = make(map[string]string) + for { + idx := bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMesage"} + } + key := string(src[rp : rp+idx]) + rp += idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMesage"} + } + value := string(src[rp : rp+idx]) + rp += idx + 1 + + dst.Parameters[key] = value + + if len(src[rp:]) == 1 { + if src[rp] != 0 { + return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) + } + break + } + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *StartupMessage) Encode(dst []byte) []byte { + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint32(dst, src.ProtocolVersion) + for k, v := range src.Parameters { + dst = append(dst, k...) + dst = append(dst, 0) + dst = append(dst, v...) + dst = append(dst, 0) + } + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src StartupMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "StartupMessage", + ProtocolVersion: src.ProtocolVersion, + Parameters: src.Parameters, + }) +} diff --git a/pgproto3/sync.go b/pgproto3/sync.go new file mode 100644 index 00000000..5db8e07a --- /dev/null +++ b/pgproto3/sync.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Sync struct{} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Sync) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Sync) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Sync) Encode(dst []byte) []byte { + return append(dst, 'S', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Sync) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Sync", + }) +} diff --git a/pgproto3/terminate.go b/pgproto3/terminate.go new file mode 100644 index 00000000..135191ea --- /dev/null +++ b/pgproto3/terminate.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Terminate struct{} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Terminate) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Terminate) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Terminate) Encode(dst []byte) []byte { + return append(dst, 'X', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Terminate) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Terminate", + }) +} diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 b/pgproto3/testdata/fuzz/FuzzFrontend/39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 new file mode 100644 index 00000000..d1c612d3 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('A') +uint32(5) +[]byte("0") diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7 b/pgproto3/testdata/fuzz/FuzzFrontend/9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7 new file mode 100644 index 00000000..763b70ae --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('D') +uint32(21) +[]byte("00\xb300000000000000") diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214 b/pgproto3/testdata/fuzz/FuzzFrontend/a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214 new file mode 100644 index 00000000..3d995c28 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('C') +uint32(4) +[]byte("0") diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4 b/pgproto3/testdata/fuzz/FuzzFrontend/fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4 new file mode 100644 index 00000000..45f0ba81 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('R') +uint32(13) +[]byte("\x00\x00\x00\n0\x12\xebG\x8dI']G\xdac\x95\xb7\x18\xb0\x02\xe8m\xc2\x00\xef\x03\x12\x1b\xbdj\x10\x9f\xf9\xeb\xb8") diff --git a/pgproto3/trace.go b/pgproto3/trace.go new file mode 100644 index 00000000..c09f68d1 --- /dev/null +++ b/pgproto3/trace.go @@ -0,0 +1,440 @@ +package pgproto3 + +import ( + "bytes" + "fmt" + "io" + "strconv" + "strings" + "time" +) + +// tracer traces the messages send to and from a Backend or Frontend. The format it produces roughly mimics the +// format produced by the libpq C function PQtrace. +type tracer struct { + w io.Writer + buf *bytes.Buffer + TracerOptions +} + +// TracerOptions controls tracing behavior. It is roughly equivalent to the libpq function PQsetTraceFlags. +type TracerOptions struct { + // SuppressTimestamps prevents printing of timestamps. + SuppressTimestamps bool + + // RegressMode redacts fields that may be vary between executions. + RegressMode bool +} + +func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) { + switch msg := msg.(type) { + case *AuthenticationCleartextPassword: + t.traceAuthenticationCleartextPassword(sender, encodedLen, msg) + case *AuthenticationGSS: + t.traceAuthenticationGSS(sender, encodedLen, msg) + case *AuthenticationGSSContinue: + t.traceAuthenticationGSSContinue(sender, encodedLen, msg) + case *AuthenticationMD5Password: + t.traceAuthenticationMD5Password(sender, encodedLen, msg) + case *AuthenticationOk: + t.traceAuthenticationOk(sender, encodedLen, msg) + case *AuthenticationSASL: + t.traceAuthenticationSASL(sender, encodedLen, msg) + case *AuthenticationSASLContinue: + t.traceAuthenticationSASLContinue(sender, encodedLen, msg) + case *AuthenticationSASLFinal: + t.traceAuthenticationSASLFinal(sender, encodedLen, msg) + case *BackendKeyData: + t.traceBackendKeyData(sender, encodedLen, msg) + case *Bind: + t.traceBind(sender, encodedLen, msg) + case *BindComplete: + t.traceBindComplete(sender, encodedLen, msg) + case *CancelRequest: + t.traceCancelRequest(sender, encodedLen, msg) + case *Close: + t.traceClose(sender, encodedLen, msg) + case *CloseComplete: + t.traceCloseComplete(sender, encodedLen, msg) + case *CommandComplete: + t.traceCommandComplete(sender, encodedLen, msg) + case *CopyBothResponse: + t.traceCopyBothResponse(sender, encodedLen, msg) + case *CopyData: + t.traceCopyData(sender, encodedLen, msg) + case *CopyDone: + t.traceCopyDone(sender, encodedLen, msg) + case *CopyFail: + t.traceCopyFail(sender, encodedLen, msg) + case *CopyInResponse: + t.traceCopyInResponse(sender, encodedLen, msg) + case *CopyOutResponse: + t.traceCopyOutResponse(sender, encodedLen, msg) + case *DataRow: + t.traceDataRow(sender, encodedLen, msg) + case *Describe: + t.traceDescribe(sender, encodedLen, msg) + case *EmptyQueryResponse: + t.traceEmptyQueryResponse(sender, encodedLen, msg) + case *ErrorResponse: + t.traceErrorResponse(sender, encodedLen, msg) + case *Execute: + t.TraceQueryute(sender, encodedLen, msg) + case *Flush: + t.traceFlush(sender, encodedLen, msg) + case *FunctionCall: + t.traceFunctionCall(sender, encodedLen, msg) + case *FunctionCallResponse: + t.traceFunctionCallResponse(sender, encodedLen, msg) + case *GSSEncRequest: + t.traceGSSEncRequest(sender, encodedLen, msg) + case *NoData: + t.traceNoData(sender, encodedLen, msg) + case *NoticeResponse: + t.traceNoticeResponse(sender, encodedLen, msg) + case *NotificationResponse: + t.traceNotificationResponse(sender, encodedLen, msg) + case *ParameterDescription: + t.traceParameterDescription(sender, encodedLen, msg) + case *ParameterStatus: + t.traceParameterStatus(sender, encodedLen, msg) + case *Parse: + t.traceParse(sender, encodedLen, msg) + case *ParseComplete: + t.traceParseComplete(sender, encodedLen, msg) + case *PortalSuspended: + t.tracePortalSuspended(sender, encodedLen, msg) + case *Query: + t.traceQuery(sender, encodedLen, msg) + case *ReadyForQuery: + t.traceReadyForQuery(sender, encodedLen, msg) + case *RowDescription: + t.traceRowDescription(sender, encodedLen, msg) + case *SSLRequest: + t.traceSSLRequest(sender, encodedLen, msg) + case *StartupMessage: + t.traceStartupMessage(sender, encodedLen, msg) + case *Sync: + t.traceSync(sender, encodedLen, msg) + case *Terminate: + t.traceTerminate(sender, encodedLen, msg) + default: + t.beginTrace(sender, encodedLen, "Unknown") + t.finishTrace() + } +} + +func (t *tracer) traceAuthenticationCleartextPassword(sender byte, encodedLen int32, msg *AuthenticationCleartextPassword) { + t.beginTrace(sender, encodedLen, "AuthenticationCleartextPassword") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationGSS(sender byte, encodedLen int32, msg *AuthenticationGSS) { + t.beginTrace(sender, encodedLen, "AuthenticationGSS") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationGSSContinue(sender byte, encodedLen int32, msg *AuthenticationGSSContinue) { + t.beginTrace(sender, encodedLen, "AuthenticationGSSContinue") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationMD5Password(sender byte, encodedLen int32, msg *AuthenticationMD5Password) { + t.beginTrace(sender, encodedLen, "AuthenticationMD5Password") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationOk(sender byte, encodedLen int32, msg *AuthenticationOk) { + t.beginTrace(sender, encodedLen, "AuthenticationOk") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationSASL(sender byte, encodedLen int32, msg *AuthenticationSASL) { + t.beginTrace(sender, encodedLen, "AuthenticationSASL") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationSASLContinue(sender byte, encodedLen int32, msg *AuthenticationSASLContinue) { + t.beginTrace(sender, encodedLen, "AuthenticationSASLContinue") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationSASLFinal(sender byte, encodedLen int32, msg *AuthenticationSASLFinal) { + t.beginTrace(sender, encodedLen, "AuthenticationSASLFinal") + t.finishTrace() +} + +func (t *tracer) traceBackendKeyData(sender byte, encodedLen int32, msg *BackendKeyData) { + t.beginTrace(sender, encodedLen, "BackendKeyData") + if t.RegressMode { + t.buf.WriteString("\t NNNN NNNN") + } else { + fmt.Fprintf(t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey) + } + t.finishTrace() +} + +func (t *tracer) traceBind(sender byte, encodedLen int32, msg *Bind) { + t.beginTrace(sender, encodedLen, "Bind") + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) + for _, fc := range msg.ParameterFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + fmt.Fprintf(t.buf, " %d", len(msg.Parameters)) + for _, p := range msg.Parameters { + fmt.Fprintf(t.buf, " %s", traceSingleQuotedString(p)) + } + fmt.Fprintf(t.buf, " %d", len(msg.ResultFormatCodes)) + for _, fc := range msg.ResultFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + t.finishTrace() +} + +func (t *tracer) traceBindComplete(sender byte, encodedLen int32, msg *BindComplete) { + t.beginTrace(sender, encodedLen, "BindComplete") + t.finishTrace() +} + +func (t *tracer) traceCancelRequest(sender byte, encodedLen int32, msg *CancelRequest) { + t.beginTrace(sender, encodedLen, "CancelRequest") + t.finishTrace() +} + +func (t *tracer) traceClose(sender byte, encodedLen int32, msg *Close) { + t.beginTrace(sender, encodedLen, "Close") + t.finishTrace() +} + +func (t *tracer) traceCloseComplete(sender byte, encodedLen int32, msg *CloseComplete) { + t.beginTrace(sender, encodedLen, "CloseComplete") + t.finishTrace() +} + +func (t *tracer) traceCommandComplete(sender byte, encodedLen int32, msg *CommandComplete) { + t.beginTrace(sender, encodedLen, "CommandComplete") + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString(msg.CommandTag)) + t.finishTrace() +} + +func (t *tracer) traceCopyBothResponse(sender byte, encodedLen int32, msg *CopyBothResponse) { + t.beginTrace(sender, encodedLen, "CopyBothResponse") + t.finishTrace() +} + +func (t *tracer) traceCopyData(sender byte, encodedLen int32, msg *CopyData) { + t.beginTrace(sender, encodedLen, "CopyData") + t.finishTrace() +} + +func (t *tracer) traceCopyDone(sender byte, encodedLen int32, msg *CopyDone) { + t.beginTrace(sender, encodedLen, "CopyDone") + t.finishTrace() +} + +func (t *tracer) traceCopyFail(sender byte, encodedLen int32, msg *CopyFail) { + t.beginTrace(sender, encodedLen, "CopyFail") + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.Message))) + t.finishTrace() +} + +func (t *tracer) traceCopyInResponse(sender byte, encodedLen int32, msg *CopyInResponse) { + t.beginTrace(sender, encodedLen, "CopyInResponse") + t.finishTrace() +} + +func (t *tracer) traceCopyOutResponse(sender byte, encodedLen int32, msg *CopyOutResponse) { + t.beginTrace(sender, encodedLen, "CopyOutResponse") + t.finishTrace() +} + +func (t *tracer) traceDataRow(sender byte, encodedLen int32, msg *DataRow) { + t.beginTrace(sender, encodedLen, "DataRow") + fmt.Fprintf(t.buf, "\t %d", len(msg.Values)) + for _, v := range msg.Values { + if v == nil { + t.buf.WriteString(" -1") + } else { + fmt.Fprintf(t.buf, " %d %s", len(v), traceSingleQuotedString(v)) + } + } + t.finishTrace() +} + +func (t *tracer) traceDescribe(sender byte, encodedLen int32, msg *Describe) { + t.beginTrace(sender, encodedLen, "Describe") + fmt.Fprintf(t.buf, "\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) + t.finishTrace() +} + +func (t *tracer) traceEmptyQueryResponse(sender byte, encodedLen int32, msg *EmptyQueryResponse) { + t.beginTrace(sender, encodedLen, "EmptyQueryResponse") + t.finishTrace() +} + +func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorResponse) { + t.beginTrace(sender, encodedLen, "ErrorResponse") + t.finishTrace() +} + +func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) { + t.beginTrace(sender, encodedLen, "Execute") + fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) + t.finishTrace() +} + +func (t *tracer) traceFlush(sender byte, encodedLen int32, msg *Flush) { + t.beginTrace(sender, encodedLen, "Flush") + t.finishTrace() +} + +func (t *tracer) traceFunctionCall(sender byte, encodedLen int32, msg *FunctionCall) { + t.beginTrace(sender, encodedLen, "FunctionCall") + t.finishTrace() +} + +func (t *tracer) traceFunctionCallResponse(sender byte, encodedLen int32, msg *FunctionCallResponse) { + t.beginTrace(sender, encodedLen, "FunctionCallResponse") + t.finishTrace() +} + +func (t *tracer) traceGSSEncRequest(sender byte, encodedLen int32, msg *GSSEncRequest) { + t.beginTrace(sender, encodedLen, "GSSEncRequest") + t.finishTrace() +} + +func (t *tracer) traceNoData(sender byte, encodedLen int32, msg *NoData) { + t.beginTrace(sender, encodedLen, "NoData") + t.finishTrace() +} + +func (t *tracer) traceNoticeResponse(sender byte, encodedLen int32, msg *NoticeResponse) { + t.beginTrace(sender, encodedLen, "NoticeResponse") + t.finishTrace() +} + +func (t *tracer) traceNotificationResponse(sender byte, encodedLen int32, msg *NotificationResponse) { + t.beginTrace(sender, encodedLen, "NotificationResponse") + fmt.Fprintf(t.buf, "\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) + t.finishTrace() +} + +func (t *tracer) traceParameterDescription(sender byte, encodedLen int32, msg *ParameterDescription) { + t.beginTrace(sender, encodedLen, "ParameterDescription") + t.finishTrace() +} + +func (t *tracer) traceParameterStatus(sender byte, encodedLen int32, msg *ParameterStatus) { + t.beginTrace(sender, encodedLen, "ParameterStatus") + fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) + t.finishTrace() +} + +func (t *tracer) traceParse(sender byte, encodedLen int32, msg *Parse) { + t.beginTrace(sender, encodedLen, "Parse") + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) + for _, oid := range msg.ParameterOIDs { + fmt.Fprintf(t.buf, " %d", oid) + } + t.finishTrace() +} + +func (t *tracer) traceParseComplete(sender byte, encodedLen int32, msg *ParseComplete) { + t.beginTrace(sender, encodedLen, "ParseComplete") + t.finishTrace() +} + +func (t *tracer) tracePortalSuspended(sender byte, encodedLen int32, msg *PortalSuspended) { + t.beginTrace(sender, encodedLen, "PortalSuspended") + t.finishTrace() +} + +func (t *tracer) traceQuery(sender byte, encodedLen int32, msg *Query) { + t.beginTrace(sender, encodedLen, "Query") + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.String))) + t.finishTrace() +} + +func (t *tracer) traceReadyForQuery(sender byte, encodedLen int32, msg *ReadyForQuery) { + t.beginTrace(sender, encodedLen, "ReadyForQuery") + fmt.Fprintf(t.buf, "\t %c", msg.TxStatus) + t.finishTrace() +} + +func (t *tracer) traceRowDescription(sender byte, encodedLen int32, msg *RowDescription) { + t.beginTrace(sender, encodedLen, "RowDescription") + fmt.Fprintf(t.buf, "\t %d", len(msg.Fields)) + for _, fd := range msg.Fields { + fmt.Fprintf(t.buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) + } + t.finishTrace() +} + +func (t *tracer) traceSSLRequest(sender byte, encodedLen int32, msg *SSLRequest) { + t.beginTrace(sender, encodedLen, "SSLRequest") + t.finishTrace() +} + +func (t *tracer) traceStartupMessage(sender byte, encodedLen int32, msg *StartupMessage) { + t.beginTrace(sender, encodedLen, "StartupMessage") + t.finishTrace() +} + +func (t *tracer) traceSync(sender byte, encodedLen int32, msg *Sync) { + t.beginTrace(sender, encodedLen, "Sync") + t.finishTrace() +} + +func (t *tracer) traceTerminate(sender byte, encodedLen int32, msg *Terminate) { + t.beginTrace(sender, encodedLen, "Terminate") + t.finishTrace() +} + +func (t *tracer) beginTrace(sender byte, encodedLen int32, msgType string) { + if !t.SuppressTimestamps { + now := time.Now() + t.buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) + t.buf.WriteByte('\t') + } + + t.buf.WriteByte(sender) + t.buf.WriteByte('\t') + t.buf.WriteString(msgType) + t.buf.WriteByte('\t') + t.buf.WriteString(strconv.FormatInt(int64(encodedLen), 10)) +} + +func (t *tracer) finishTrace() { + t.buf.WriteByte('\n') + t.buf.WriteTo(t.w) + + if t.buf.Cap() > 1024 { + t.buf = &bytes.Buffer{} + } else { + t.buf.Reset() + } +} + +// traceDoubleQuotedString returns t.buf as a double-quoted string without any escaping. It is roughly equivalent to +// pqTraceOutputString in libpq. +func traceDoubleQuotedString(buf []byte) string { + return `"` + string(buf) + `"` +} + +// traceSingleQuotedString returns buf as a single-quoted string with non-printable characters hex-escaped. It is +// roughly equivalent to pqTraceOutputNchar in libpq. +func traceSingleQuotedString(buf []byte) string { + sb := &strings.Builder{} + + sb.WriteByte('\'') + for _, b := range buf { + if b < 32 || b > 126 { + fmt.Fprintf(sb, `\x%x`, b) + } else { + sb.WriteByte(b) + } + } + sb.WriteByte('\'') + + return sb.String() +} diff --git a/pgproto3/trace_test.go b/pgproto3/trace_test.go new file mode 100644 index 00000000..904bbdfb --- /dev/null +++ b/pgproto3/trace_test.go @@ -0,0 +1,56 @@ +package pgproto3_test + +import ( + "bytes" + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestTrace(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + if conn.ParameterStatus("crdb_version") != "" { + t.Skip("Skipping message trace on CockroachDB as it varies slightly from PostgreSQL") + } + + traceOutput := &bytes.Buffer{} + conn.Frontend().Trace(traceOutput, pgproto3.TracerOptions{ + SuppressTimestamps: true, + RegressMode: true, + }) + + result := conn.ExecParams(ctx, "select n from generate_series(1,5) n", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + expected := `F Parse 45 "" "select n from generate_series(1,5) n" 0 +F Bind 13 "" "" 0 0 0 +F Describe 7 P "" +F Execute 10 "" 0 +F Sync 5 +B ParseComplete 5 +B BindComplete 5 +B RowDescription 27 1 "n" 0 0 23 4 -1 0 +B DataRow 12 1 1 '1' +B DataRow 12 1 1 '2' +B DataRow 12 1 1 '3' +B DataRow 12 1 1 '4' +B DataRow 12 1 1 '5' +B CommandComplete 14 "SELECT 5" +B ReadyForQuery 6 I +` + + require.Equal(t, expected, traceOutput.String()) +} diff --git a/pgtype/array.go b/pgtype/array.go new file mode 100644 index 00000000..0fa4c129 --- /dev/null +++ b/pgtype/array.go @@ -0,0 +1,481 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "reflect" + "strconv" + "strings" + "unicode" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// Information on the internals of PostgreSQL arrays can be found in +// src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of +// particular interest is the array_send function. + +type arrayHeader struct { + ContainsNull bool + ElementOID uint32 + Dimensions []ArrayDimension +} + +type ArrayDimension struct { + Length int32 + LowerBound int32 +} + +// cardinality returns the number of elements in an array of dimensions size. +func cardinality(dimensions []ArrayDimension) int { + if len(dimensions) == 0 { + return 0 + } + + elementCount := int(dimensions[0].Length) + for _, d := range dimensions[1:] { + elementCount *= int(d.Length) + } + + return elementCount +} + +func (dst *arrayHeader) DecodeBinary(m *Map, src []byte) (int, error) { + if len(src) < 12 { + return 0, fmt.Errorf("array header too short: %d", len(src)) + } + + rp := 0 + + numDims := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 + rp += 4 + + dst.ElementOID = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + + dst.Dimensions = make([]ArrayDimension, numDims) + if len(src) < 12+numDims*8 { + return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) + } + for i := range dst.Dimensions { + dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + } + + return rp, nil +} + +func (src arrayHeader) EncodeBinary(buf []byte) []byte { + buf = pgio.AppendInt32(buf, int32(len(src.Dimensions))) + + var containsNull int32 + if src.ContainsNull { + containsNull = 1 + } + buf = pgio.AppendInt32(buf, containsNull) + + buf = pgio.AppendUint32(buf, src.ElementOID) + + for i := range src.Dimensions { + buf = pgio.AppendInt32(buf, src.Dimensions[i].Length) + buf = pgio.AppendInt32(buf, src.Dimensions[i].LowerBound) + } + + return buf +} + +type untypedTextArray struct { + Elements []string + Quoted []bool + Dimensions []ArrayDimension +} + +func parseUntypedTextArray(src string) (*untypedTextArray, error) { + dst := &untypedTextArray{ + Elements: []string{}, + Quoted: []bool{}, + Dimensions: []ArrayDimension{}, + } + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + var explicitDimensions []ArrayDimension + + // Array has explicit dimensions + if r == '[' { + buf.UnreadRune() + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '=' { + break + } else if r != '[' { + return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) + } + + lower, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ':' { + return nil, fmt.Errorf("invalid array, expected ':' got %v", r) + } + + upper, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ']' { + return nil, fmt.Errorf("invalid array, expected ']' got %v", r) + } + + explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + } + + if r != '{' { + return nil, fmt.Errorf("invalid array, expected '{': %v", err) + } + + implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} + + // Consume all initial opening brackets. This provides number of dimensions. + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '{' { + implicitDimensions[len(implicitDimensions)-1].Length = 1 + implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1}) + } else { + buf.UnreadRune() + break + } + } + currentDim := len(implicitDimensions) - 1 + counterDim := currentDim + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + switch r { + case '{': + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + currentDim++ + case ',': + case '}': + currentDim-- + if currentDim < counterDim { + counterDim = currentDim + } + default: + buf.UnreadRune() + value, quoted, err := arrayParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid array value: %v", err) + } + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + dst.Quoted = append(dst.Quoted, quoted) + dst.Elements = append(dst.Elements, value) + } + + if currentDim < 0 { + break + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + if len(dst.Elements) == 0 { + } else if len(explicitDimensions) > 0 { + dst.Dimensions = explicitDimensions + } else { + dst.Dimensions = implicitDimensions + } + + return dst, nil +} + +func skipWhitespace(buf *bytes.Buffer) { + var r rune + var err error + for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() { + } + + if err != io.EOF { + buf.UnreadRune() + } +} + +func arrayParseValue(buf *bytes.Buffer) (string, bool, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", false, err + } + if r == '"' { + return arrayParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", false, err + } + + switch r { + case ',', '}': + buf.UnreadRune() + return s.String(), false, nil + } + + s.WriteRune(r) + } +} + +func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", false, err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", false, err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", false, err + } + buf.UnreadRune() + return s.String(), true, nil + } + s.WriteRune(r) + } +} + +func arrayParseInteger(buf *bytes.Buffer) (int32, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return 0, err + } + + if ('0' <= r && r <= '9') || r == '-' { + s.WriteRune(r) + } else { + buf.UnreadRune() + n, err := strconv.ParseInt(s.String(), 10, 32) + if err != nil { + return 0, err + } + return int32(n), nil + } + } +} + +func encodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte { + var customDimensions bool + for _, dim := range dimensions { + if dim.LowerBound != 1 { + customDimensions = true + } + } + + if !customDimensions { + return buf + } + + for _, dim := range dimensions { + buf = append(buf, '[') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound), 10)...) + buf = append(buf, ':') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)...) + buf = append(buf, ']') + } + + return append(buf, '=') +} + +var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteArrayElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func isSpace(ch byte) bool { + // see https://github.com/postgres/postgres/blob/REL_12_STABLE/src/backend/parser/scansup.c#L224 + return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f' +} + +func quoteArrayElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { + return quoteArrayElement(src) + } + return src +} + +func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + length := value.Len() + if 0 == elementsLength { + elementsLength = length + } else { + elementsLength *= length + } + dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1}) + for i := 0; i < length; i++ { + if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok { + return d, l, true + } + } + } + return dimensions, elementsLength, true +} + +// Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves +// PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed. +type Array[T any] struct { + Elements []T + Dims []ArrayDimension + Valid bool +} + +func (a Array[T]) Dimensions() []ArrayDimension { + return a.Dims +} + +func (a Array[T]) Index(i int) any { + return a.Elements[i] +} + +func (a Array[T]) IndexType() any { + var el T + return el +} + +func (a *Array[T]) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + *a = Array[T]{} + return nil + } + + elementCount := cardinality(dimensions) + *a = Array[T]{ + Elements: make([]T, elementCount), + Dims: dimensions, + Valid: true, + } + + return nil +} + +func (a Array[T]) ScanIndex(i int) any { + return &a.Elements[i] +} + +func (a Array[T]) ScanIndexType() any { + return new(T) +} + +// FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions +// and custom lower bounds. Use Array to preserve these. +type FlatArray[T any] []T + +func (a FlatArray[T]) Dimensions() []ArrayDimension { + if a == nil { + return nil + } + + return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} +} + +func (a FlatArray[T]) Index(i int) any { + return a[i] +} + +func (a FlatArray[T]) IndexType() any { + var el T + return el +} + +func (a *FlatArray[T]) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + *a = nil + return nil + } + + elementCount := cardinality(dimensions) + *a = make(FlatArray[T], elementCount) + return nil +} + +func (a FlatArray[T]) ScanIndex(i int) any { + return &a[i] +} + +func (a FlatArray[T]) ScanIndexType() any { + return new(T) +} diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go new file mode 100644 index 00000000..dae12039 --- /dev/null +++ b/pgtype/array_codec.go @@ -0,0 +1,395 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgx/v5/internal/anynil" + "github.com/jackc/pgx/v5/internal/pgio" +) + +// ArrayGetter is a type that can be converted into a PostgreSQL array. +type ArrayGetter interface { + // Dimensions returns the array dimensions. If array is nil then nil is returned. + Dimensions() []ArrayDimension + + // Index returns the element at i. + Index(i int) any + + // IndexType returns a non-nil scan target of the type Index will return. This is used by ArrayCodec.PlanEncode. + IndexType() any +} + +// ArraySetter is a type can be set from a PostgreSQL array. +type ArraySetter interface { + // SetDimensions prepares the value such that ScanIndex can be called for each element. This will remove any existing + // elements. dimensions may be nil to indicate a NULL array. If unable to exactly preserve dimensions SetDimensions + // may return an error or silently flatten the array dimensions. + SetDimensions(dimensions []ArrayDimension) error + + // ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex. + ScanIndex(i int) any + + // ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by + // ArrayCodec.PlanScan. + ScanIndexType() any +} + +// ArrayCodec is a codec for any array type. +type ArrayCodec struct { + ElementType *Type +} + +func (c *ArrayCodec) FormatSupported(format int16) bool { + return c.ElementType.Codec.FormatSupported(format) +} + +func (c *ArrayCodec) PreferredFormat() int16 { + return c.ElementType.Codec.PreferredFormat() +} + +func (c *ArrayCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + arrayValuer, ok := value.(ArrayGetter) + if !ok { + return nil + } + + elementType := arrayValuer.IndexType() + + elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType) + if elementEncodePlan == nil { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanArrayCodecBinary{ac: c, m: m, oid: oid} + case TextFormatCode: + return &encodePlanArrayCodecText{ac: c, m: m, oid: oid} + } + + return nil +} + +type encodePlanArrayCodecText struct { + ac *ArrayCodec + m *Map + oid uint32 +} + +func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + array := value.(ArrayGetter) + + dimensions := array.Dimensions() + if dimensions == nil { + return nil, nil + } + + elementCount := cardinality(dimensions) + if elementCount == 0 { + return append(buf, '{', '}'), nil + } + + buf = encodeTextArrayDimensions(buf, dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(dimensions)) + dimElemCounts[len(dimensions)-1] = int(dimensions[len(dimensions)-1].Length) + for i := len(dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(dimensions[i].Length) * dimElemCounts[i+1] + } + + var encodePlan EncodePlan + var lastElemType reflect.Type + inElemBuf := make([]byte, 0, 32) + for i := 0; i < elementCount; i++ { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elem := array.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, inElemBuf) + if err != nil { + return nil, err + } + } + + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, quoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +type encodePlanArrayCodecBinary struct { + ac *ArrayCodec + m *Map + oid uint32 +} + +func (p *encodePlanArrayCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + array := value.(ArrayGetter) + + dimensions := array.Dimensions() + if dimensions == nil { + return nil, nil + } + + arrayHeader := arrayHeader{ + Dimensions: dimensions, + ElementOID: p.ac.ElementType.OID, + } + + containsNullIndex := len(buf) + 4 + + buf = arrayHeader.EncodeBinary(buf) + + elementCount := cardinality(dimensions) + + var encodePlan EncodePlan + var lastElemType reflect.Type + for i := 0; i < elementCount; i++ { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elem := array.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, buf) + if err != nil { + return nil, err + } + } + + if elemBuf == nil { + pgio.SetInt32(buf[containsNullIndex:], 1) + } else { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + arrayScanner, ok := target.(ArraySetter) + if !ok { + return nil + } + + // target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the + // scan of the elements. + if anynil.Is(target) { + arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter) + } + + elementType := arrayScanner.ScanIndexType() + + elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType) + if _, ok := elementScanPlan.(*scanPlanFail); ok { + return nil + } + + return &scanPlanArrayCodec{ + arrayCodec: c, + m: m, + oid: oid, + formatCode: format, + } +} + +func (c *ArrayCodec) decodeBinary(m *Map, arrayOID uint32, src []byte, array ArraySetter) error { + var arrayHeader arrayHeader + rp, err := arrayHeader.DecodeBinary(m, src) + if err != nil { + return err + } + + err = array.SetDimensions(arrayHeader.Dimensions) + if err != nil { + return err + } + + elementCount := cardinality(arrayHeader.Dimensions) + if elementCount == 0 { + return nil + } + + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) + if elementScanPlan == nil { + elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) + } + + for i := 0; i < elementCount; i++ { + elem := array.ScanIndex(i) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elementScanPlan.Scan(elemSrc, elem) + if err != nil { + return fmt.Errorf("failed to scan array element %d: %w", i, err) + } + } + + return nil +} + +func (c *ArrayCodec) decodeText(m *Map, arrayOID uint32, src []byte, array ArraySetter) error { + uta, err := parseUntypedTextArray(string(src)) + if err != nil { + return err + } + + err = array.SetDimensions(uta.Dimensions) + if err != nil { + return err + } + + if len(uta.Elements) == 0 { + return nil + } + + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, array.ScanIndex(0)) + if elementScanPlan == nil { + elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, array.ScanIndex(0)) + } + + for i, s := range uta.Elements { + elem := array.ScanIndex(i) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + + err = elementScanPlan.Scan(elemSrc, elem) + if err != nil { + return err + } + } + + return nil +} + +type scanPlanArrayCodec struct { + arrayCodec *ArrayCodec + m *Map + oid uint32 + formatCode int16 + elementScanPlan ScanPlan +} + +func (spac *scanPlanArrayCodec) Scan(src []byte, dst any) error { + c := spac.arrayCodec + m := spac.m + oid := spac.oid + formatCode := spac.formatCode + + array := dst.(ArraySetter) + + if src == nil { + return array.SetDimensions(nil) + } + + switch formatCode { + case BinaryFormatCode: + return c.decodeBinary(m, oid, src, array) + case TextFormatCode: + return c.decodeText(m, oid, src, array) + default: + return fmt.Errorf("unknown format code %d", formatCode) + } +} + +func (c *ArrayCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (c *ArrayCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var slice []any + err := m.PlanScan(oid, format, &slice).Scan(src, &slice) + return slice, err +} + +func isRagged(slice reflect.Value) bool { + if slice.Type().Elem().Kind() != reflect.Slice { + return false + } + + sliceLen := slice.Len() + innerLen := 0 + for i := 0; i < sliceLen; i++ { + if i == 0 { + innerLen = slice.Index(i).Len() + } else { + if slice.Index(i).Len() != innerLen { + return true + } + } + if isRagged(slice.Index(i)) { + return true + } + } + + return false +} diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go new file mode 100644 index 00000000..a558d0fc --- /dev/null +++ b/pgtype/array_codec_test.go @@ -0,0 +1,274 @@ +package pgtype_test + +import ( + "context" + "encoding/hex" + "strings" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestArrayCodec(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, tt := range []struct { + expected any + }{ + {[]int16(nil)}, + {[]int16{}}, + {[]int16{1, 2, 3}}, + } { + var actual []int16 + err := conn.QueryRow( + ctx, + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + + newInt16 := func(n int16) *int16 { return &n } + + for i, tt := range []struct { + expected any + }{ + {[]*int16{newInt16(1), nil, newInt16(3), nil, newInt16(5)}}, + } { + var actual []*int16 + err := conn.QueryRow( + ctx, + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + +func TestArrayCodecFlatArray(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, tt := range []struct { + expected any + }{ + {pgtype.FlatArray[int32](nil)}, + {pgtype.FlatArray[int32]{}}, + {pgtype.FlatArray[int32]{1, 2, 3}}, + } { + var actual pgtype.FlatArray[int32] + err := conn.QueryRow( + ctx, + "select $1::int[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + +func TestArrayCodecArray(t *testing.T) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support multi-dimensional arrays") + } + + ctr.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, tt := range []struct { + expected any + }{ + {pgtype.Array[int32]{ + Elements: []int32{1, 2, 3, 4}, + Dims: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 2}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }}, + } { + var actual pgtype.Array[int32] + err := conn.QueryRow( + ctx, + "select $1::int[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + +func TestArrayCodecAnySlice(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + type _int16Slice []int16 + + for i, tt := range []struct { + expected any + }{ + {_int16Slice(nil)}, + {_int16Slice{}}, + {_int16Slice{1, 2, 3}}, + } { + var actual _int16Slice + err := conn.QueryRow( + ctx, + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + +// https://github.com/jackc/pgx/issues/1273#issuecomment-1218262703 +func TestArrayCodecSliceArgConversion(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + arg := []string{ + "3ad95bfd-ecea-4032-83c3-0c823cafb372", + "951baf11-c0cc-4afc-a779-abff0611dbf1", + "8327f244-7e2f-45e7-a10b-fbdc9d6f3378", + } + + var expected []pgtype.UUID + + for _, s := range arg { + buf, err := hex.DecodeString(strings.ReplaceAll(s, "-", "")) + require.NoError(t, err) + var u pgtype.UUID + copy(u.Bytes[:], buf) + u.Valid = true + expected = append(expected, u) + } + + var actual []pgtype.UUID + err := conn.QueryRow( + ctx, + "select $1::uuid[]", + arg, + ).Scan(&actual) + require.NoError(t, err) + require.Equal(t, expected, actual) + }) +} + +func TestArrayCodecDecodeValue(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select '{}'::int4[]`, + expected: []any{}, + }, + { + sql: `select '{1,2}'::int8[]`, + expected: []any{int64(1), int64(2)}, + }, + { + sql: `select '{foo,bar}'::text[]`, + expected: []any{"foo", "bar"}, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} + +func TestArrayCodecScanMultipleDimensions(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + + rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) + require.NoError(t, err) + + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) + } + + require.NoError(t, rows.Err()) + }) +} + +func TestArrayCodecScanMultipleDimensionsEmpty(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select '{}'::int4[]`) + require.NoError(t, err) + + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{}, ss) + } + + require.NoError(t, rows.Err()) + }) +} + +func TestArrayCodecScanWrongMultipleDimensions(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) + require.NoError(t, err) + + for rows.Next() { + var ss [][][]int32 + err := rows.Scan(&ss) + require.Error(t, err, "can't scan into dest[0]: PostgreSQL array has 2 dimensions but slice has 3 dimensions") + } + }) +} + +func TestArrayCodecEncodeMultipleDimensions(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}) + require.NoError(t, err) + + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) + } + + require.NoError(t, rows.Err()) + }) +} + +func TestArrayCodecEncodeMultipleDimensionsRagged(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5}, {9, 10, 11, 12}}) + require.Error(t, err, "cannot convert [][]int32 to ArrayGetter because it is a ragged multi-dimensional") + defer rows.Close() + }) +} diff --git a/pgtype/array_test.go b/pgtype/array_test.go new file mode 100644 index 00000000..f246b346 --- /dev/null +++ b/pgtype/array_test.go @@ -0,0 +1,121 @@ +package pgtype + +import ( + "reflect" + "testing" +) + +func TestParseUntypedTextArray(t *testing.T) { + tests := []struct { + source string + result untypedTextArray + }{ + { + source: "{}", + result: untypedTextArray{ + Elements: []string{}, + Quoted: []bool{}, + Dimensions: []ArrayDimension{}, + }, + }, + { + source: "{1}", + result: untypedTextArray{ + Elements: []string{"1"}, + Quoted: []bool{false}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{a,b}", + result: untypedTextArray{ + Elements: []string{"a", "b"}, + Quoted: []bool{false, false}, + Dimensions: []ArrayDimension{{Length: 2, LowerBound: 1}}, + }, + }, + { + source: `{"NULL"}`, + result: untypedTextArray{ + Elements: []string{"NULL"}, + Quoted: []bool{true}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{""}`, + result: untypedTextArray{ + Elements: []string{""}, + Quoted: []bool{true}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{"He said, \"Hello.\""}`, + result: untypedTextArray{ + Elements: []string{`He said, "Hello."`}, + Quoted: []bool{true}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{{a,b},{c,d},{e,f}}", + result: untypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f"}, + Quoted: []bool{false, false, false, false, false, false}, + Dimensions: []ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + }, + }, + { + source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", + result: untypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, + Quoted: []bool{false, false, false, false, false, false, false, false, false, false, false, false}, + Dimensions: []ArrayDimension{ + {Length: 2, LowerBound: 1}, + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + }, + }, + { + source: "[4:4]={1}", + result: untypedTextArray{ + Elements: []string{"1"}, + Quoted: []bool{false}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 4}}, + }, + }, + { + source: "[4:5][2:3]={{a,b},{c,d}}", + result: untypedTextArray{ + Elements: []string{"a", "b", "c", "d"}, + Quoted: []bool{false, false, false, false}, + Dimensions: []ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + }, + }, + { + source: "[-4:-2]={1,2,3}", + result: untypedTextArray{ + Elements: []string{"1", "2", "3"}, + Quoted: []bool{false, false, false}, + Dimensions: []ArrayDimension{{Length: 3, LowerBound: -4}}, + }, + }, + } + + for i, tt := range tests { + r, err := parseUntypedTextArray(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(*r, tt.result) { + t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r) + } + } +} diff --git a/pgtype/bits.go b/pgtype/bits.go new file mode 100644 index 00000000..30558118 --- /dev/null +++ b/pgtype/bits.go @@ -0,0 +1,208 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type BitsScanner interface { + ScanBits(v Bits) error +} + +type BitsValuer interface { + BitsValue() (Bits, error) +} + +// Bits represents the PostgreSQL bit and varbit types. +type Bits struct { + Bytes []byte + Len int32 // Number of bits + Valid bool +} + +func (b *Bits) ScanBits(v Bits) error { + *b = v + return nil +} + +func (b Bits) BitsValue() (Bits, error) { + return b, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bits) Scan(src any) error { + if src == nil { + *dst = Bits{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToBitsScanner{}.Scan([]byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bits) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := BitsCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type BitsCodec struct{} + +func (BitsCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BitsCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BitsCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(BitsValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanBitsCodecBinary{} + case TextFormatCode: + return encodePlanBitsCodecText{} + } + + return nil +} + +type encodePlanBitsCodecBinary struct{} + +func (encodePlanBitsCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + bits, err := value.(BitsValuer).BitsValue() + if err != nil { + return nil, err + } + + if !bits.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, bits.Len) + return append(buf, bits.Bytes...), nil +} + +type encodePlanBitsCodecText struct{} + +func (encodePlanBitsCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + bits, err := value.(BitsValuer).BitsValue() + if err != nil { + return nil, err + } + + if !bits.Valid { + return nil, nil + } + + for i := int32(0); i < bits.Len; i++ { + byteIdx := i / 8 + bitMask := byte(128 >> byte(i%8)) + char := byte('0') + if bits.Bytes[byteIdx]&bitMask > 0 { + char = '1' + } + buf = append(buf, char) + } + + return buf, nil +} + +func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case BitsScanner: + return scanPlanBinaryBitsToBitsScanner{} + } + case TextFormatCode: + switch target.(type) { + case BitsScanner: + return scanPlanTextAnyToBitsScanner{} + } + } + + return nil +} + +func (c BitsCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c BitsCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var box Bits + err := codecScan(c, m, oid, format, src, &box) + if err != nil { + return nil, err + } + return box, nil +} + +type scanPlanBinaryBitsToBitsScanner struct{} + +func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BitsScanner) + + if src == nil { + return scanner.ScanBits(Bits{}) + } + + if len(src) < 4 { + return fmt.Errorf("invalid length for bit/varbit: %v", len(src)) + } + + bitLen := int32(binary.BigEndian.Uint32(src)) + rp := 4 + + return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true}) +} + +type scanPlanTextAnyToBitsScanner struct{} + +func (scanPlanTextAnyToBitsScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BitsScanner) + + if src == nil { + return scanner.ScanBits(Bits{}) + } + + bitLen := len(src) + byteLen := bitLen / 8 + if bitLen%8 > 0 { + byteLen++ + } + buf := make([]byte, byteLen) + + for i, b := range src { + if b == '1' { + byteIdx := i / 8 + bitIdx := uint(i % 8) + buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) + } + } + + return scanner.ScanBits(Bits{Bytes: buf, Len: int32(bitLen), Valid: true}) +} diff --git a/pgtype/bits_test.go b/pgtype/bits_test.go new file mode 100644 index 00000000..767f0d2b --- /dev/null +++ b/pgtype/bits_test.go @@ -0,0 +1,57 @@ +package pgtype_test + +import ( + "bytes" + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqBits(a any) func(any) bool { + return func(v any) bool { + ab := a.(pgtype.Bits) + vb := v.(pgtype.Bits) + return bytes.Compare(ab.Bytes, vb.Bytes) == 0 && ab.Len == vb.Len && ab.Valid == vb.Valid + } +} + +func TestBitsCodecBit(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bit(40)", []pgxtest.ValueRoundTripTest{ + { + pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}), + }, + {pgtype.Bits{}, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + {nil, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + }) +} + +func TestBitsCodecVarbit(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "varbit", []pgxtest.ValueRoundTripTest{ + { + pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}), + }, + {pgtype.Bits{}, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + {nil, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + }) +} diff --git a/pgtype/bool.go b/pgtype/bool.go new file mode 100644 index 00000000..e7be27e2 --- /dev/null +++ b/pgtype/bool.go @@ -0,0 +1,317 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "strconv" +) + +type BoolScanner interface { + ScanBool(v Bool) error +} + +type BoolValuer interface { + BoolValue() (Bool, error) +} + +type Bool struct { + Bool bool + Valid bool +} + +func (b *Bool) ScanBool(v Bool) error { + *b = v + return nil +} + +func (b Bool) BoolValue() (Bool, error) { + return b, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bool) Scan(src any) error { + if src == nil { + *dst = Bool{} + return nil + } + + switch src := src.(type) { + case bool: + *dst = Bool{Bool: src, Valid: true} + return nil + case string: + b, err := strconv.ParseBool(src) + if err != nil { + return err + } + *dst = Bool{Bool: b, Valid: true} + return nil + case []byte: + b, err := strconv.ParseBool(string(src)) + if err != nil { + return err + } + *dst = Bool{Bool: b, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bool) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + return src.Bool, nil +} + +func (src Bool) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + if src.Bool { + return []byte("true"), nil + } else { + return []byte("false"), nil + } +} + +func (dst *Bool) UnmarshalJSON(b []byte) error { + var v *bool + err := json.Unmarshal(b, &v) + if err != nil { + return err + } + + if v == nil { + *dst = Bool{} + } else { + *dst = Bool{Bool: *v, Valid: true} + } + + return nil +} + +type BoolCodec struct{} + +func (BoolCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BoolCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BoolCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case bool: + return encodePlanBoolCodecBinaryBool{} + case BoolValuer: + return encodePlanBoolCodecBinaryBoolValuer{} + } + case TextFormatCode: + switch value.(type) { + case bool: + return encodePlanBoolCodecTextBool{} + case BoolValuer: + return encodePlanBoolCodecTextBoolValuer{} + } + } + + return nil +} + +type encodePlanBoolCodecBinaryBool struct{} + +func (encodePlanBoolCodecBinaryBool) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(bool) + + if v { + buf = append(buf, 1) + } else { + buf = append(buf, 0) + } + + return buf, nil +} + +type encodePlanBoolCodecTextBoolValuer struct{} + +func (encodePlanBoolCodecTextBoolValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + b, err := value.(BoolValuer).BoolValue() + if err != nil { + return nil, err + } + + if !b.Valid { + return nil, nil + } + + if b.Bool { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') + } + + return buf, nil +} + +type encodePlanBoolCodecBinaryBoolValuer struct{} + +func (encodePlanBoolCodecBinaryBoolValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + b, err := value.(BoolValuer).BoolValue() + if err != nil { + return nil, err + } + + if !b.Valid { + return nil, nil + } + + if b.Bool { + buf = append(buf, 1) + } else { + buf = append(buf, 0) + } + + return buf, nil +} + +type encodePlanBoolCodecTextBool struct{} + +func (encodePlanBoolCodecTextBool) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(bool) + + if v { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') + } + + return buf, nil +} + +func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *bool: + return scanPlanBinaryBoolToBool{} + case BoolScanner: + return scanPlanBinaryBoolToBoolScanner{} + } + case TextFormatCode: + switch target.(type) { + case *bool: + return scanPlanTextAnyToBool{} + case BoolScanner: + return scanPlanTextAnyToBoolScanner{} + } + } + + return nil +} + +func (c BoolCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) +} + +func (c BoolCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var b bool + err := codecScan(c, m, oid, format, src, &b) + if err != nil { + return nil, err + } + return b, nil +} + +type scanPlanBinaryBoolToBool struct{} + +func (scanPlanBinaryBoolToBool) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + p, ok := (dst).(*bool) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = src[0] == 1 + + return nil +} + +type scanPlanTextAnyToBool struct{} + +func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + p, ok := (dst).(*bool) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = src[0] == 't' + + return nil +} + +type scanPlanBinaryBoolToBoolScanner struct{} + +func (scanPlanBinaryBoolToBoolScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(BoolScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanBool(Bool{}) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + return s.ScanBool(Bool{Bool: src[0] == 1, Valid: true}) +} + +type scanPlanTextAnyToBoolScanner struct{} + +func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(BoolScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanBool(Bool{}) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + return s.ScanBool(Bool{Bool: src[0] == 't', Valid: true}) +} diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go new file mode 100644 index 00000000..7480471b --- /dev/null +++ b/pgtype/bool_test.go @@ -0,0 +1,62 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestBoolCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bool", []pgxtest.ValueRoundTripTest{ + {true, new(bool), isExpectedEq(true)}, + {false, new(bool), isExpectedEq(false)}, + {true, new(pgtype.Bool), isExpectedEq(pgtype.Bool{Bool: true, Valid: true})}, + {pgtype.Bool{}, new(pgtype.Bool), isExpectedEq(pgtype.Bool{})}, + {nil, new(*bool), isExpectedEq((*bool)(nil))}, + }) +} + +func TestBoolMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Bool + result string + }{ + {source: pgtype.Bool{}, result: "null"}, + {source: pgtype.Bool{Bool: true, Valid: true}, result: "true"}, + {source: pgtype.Bool{Bool: false, Valid: true}, result: "false"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestBoolUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Bool + }{ + {source: "null", result: pgtype.Bool{}}, + {source: "true", result: pgtype.Bool{Bool: true, Valid: true}}, + {source: "false", result: pgtype.Bool{Bool: false, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/box.go b/pgtype/box.go new file mode 100644 index 00000000..887d268b --- /dev/null +++ b/pgtype/box.go @@ -0,0 +1,238 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type BoxScanner interface { + ScanBox(v Box) error +} + +type BoxValuer interface { + BoxValue() (Box, error) +} + +type Box struct { + P [2]Vec2 + Valid bool +} + +func (b *Box) ScanBox(v Box) error { + *b = v + return nil +} + +func (b Box) BoxValue() (Box, error) { + return b, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Box) Scan(src any) error { + if src == nil { + *dst = Box{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToBoxScanner{}.Scan([]byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Box) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := BoxCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type BoxCodec struct{} + +func (BoxCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BoxCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BoxCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(BoxValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanBoxCodecBinary{} + case TextFormatCode: + return encodePlanBoxCodecText{} + } + + return nil +} + +type encodePlanBoxCodecBinary struct{} + +func (encodePlanBoxCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + box, err := value.(BoxValuer).BoxValue() + if err != nil { + return nil, err + } + + if !box.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].Y)) + return buf, nil +} + +type encodePlanBoxCodecText struct{} + +func (encodePlanBoxCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + box, err := value.(BoxValuer).BoxValue() + if err != nil { + return nil, err + } + + if !box.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(box.P[0].X, 'f', -1, 64), + strconv.FormatFloat(box.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(box.P[1].X, 'f', -1, 64), + strconv.FormatFloat(box.P[1].Y, 'f', -1, 64), + )...) + return buf, nil +} + +func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case BoxScanner: + return scanPlanBinaryBoxToBoxScanner{} + } + case TextFormatCode: + switch target.(type) { + case BoxScanner: + return scanPlanTextAnyToBoxScanner{} + } + } + + return nil +} + +type scanPlanBinaryBoxToBoxScanner struct{} + +func (scanPlanBinaryBoxToBoxScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BoxScanner) + + if src == nil { + return scanner.ScanBox(Box{}) + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + return scanner.ScanBox(Box{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Valid: true, + }) +} + +type scanPlanTextAnyToBoxScanner struct{} + +func (scanPlanTextAnyToBoxScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BoxScanner) + + if src == nil { + return scanner.ScanBox(Box{}) + } + + if len(src) < 11 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + str := string(src[1:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-1] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + return scanner.ScanBox(Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true}) +} + +func (c BoxCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c BoxCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var box Box + err := codecScan(c, m, oid, format, src, &box) + if err != nil { + return nil, err + } + return box, nil +} diff --git a/pgtype/box_test.go b/pgtype/box_test.go new file mode 100644 index 00000000..3b54c1f8 --- /dev/null +++ b/pgtype/box_test.go @@ -0,0 +1,40 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestBoxCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support box type") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "box", []pgxtest.ValueRoundTripTest{ + { + pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, + Valid: true, + }, + new(pgtype.Box), + isExpectedEq(pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, + Valid: true, + }), + }, + { + pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {-13.14, -5.234}}, + Valid: true, + }, + new(pgtype.Box), + isExpectedEq(pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {-13.14, -5.234}}, + Valid: true, + }), + }, + {pgtype.Box{}, new(pgtype.Box), isExpectedEq(pgtype.Box{})}, + {nil, new(pgtype.Box), isExpectedEq(pgtype.Box{})}, + }) +} diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go new file mode 100644 index 00000000..7a992b09 --- /dev/null +++ b/pgtype/builtin_wrappers.go @@ -0,0 +1,859 @@ +package pgtype + +import ( + "errors" + "fmt" + "math" + "net" + "net/netip" + "reflect" + "time" +) + +type int8Wrapper int8 + +func (w int8Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int8Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int8") + } + + if v.Int64 < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", v.Int64) + } + if v.Int64 > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", v.Int64) + } + *w = int8Wrapper(v.Int64) + + return nil +} + +func (w int8Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type int16Wrapper int16 + +func (w int16Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int16Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int16") + } + + if v.Int64 < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", v.Int64) + } + if v.Int64 > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", v.Int64) + } + *w = int16Wrapper(v.Int64) + + return nil +} + +func (w int16Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type int32Wrapper int32 + +func (w int32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int32") + } + + if v.Int64 < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", v.Int64) + } + if v.Int64 > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", v.Int64) + } + *w = int32Wrapper(v.Int64) + + return nil +} + +func (w int32Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type int64Wrapper int64 + +func (w int64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int64") + } + + *w = int64Wrapper(v.Int64) + + return nil +} + +func (w int64Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type intWrapper int + +func (w intWrapper) SkipUnderlyingTypePlan() {} + +func (w *intWrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int") + } + + if v.Int64 < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", v.Int64) + } + if v.Int64 > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", v.Int64) + } + + *w = intWrapper(v.Int64) + + return nil +} + +func (w intWrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type uint8Wrapper uint8 + +func (w uint8Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint8Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint8") + } + + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", v.Int64) + } + if v.Int64 > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", v.Int64) + } + *w = uint8Wrapper(v.Int64) + + return nil +} + +func (w uint8Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type uint16Wrapper uint16 + +func (w uint16Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint16Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint16") + } + + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", v.Int64) + } + if v.Int64 > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", v.Int64) + } + *w = uint16Wrapper(v.Int64) + + return nil +} + +func (w uint16Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type uint32Wrapper uint32 + +func (w uint32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint32") + } + + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", v.Int64) + } + if v.Int64 > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", v.Int64) + } + *w = uint32Wrapper(v.Int64) + + return nil +} + +func (w uint32Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type uint64Wrapper uint64 + +func (w uint64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint64") + } + + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + *w = uint64Wrapper(v.Int64) + + return nil +} + +func (w uint64Wrapper) Int64Value() (Int8, error) { + if uint64(w) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w) + } + + return Int8{Int64: int64(w), Valid: true}, nil +} + +type uintWrapper uint + +func (w uintWrapper) SkipUnderlyingTypePlan() {} + +func (w *uintWrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint64") + } + + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + if uint64(v.Int64) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", v.Int64) + } + + *w = uintWrapper(v.Int64) + + return nil +} + +func (w uintWrapper) Int64Value() (Int8, error) { + if uint64(w) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w) + } + + return Int8{Int64: int64(w), Valid: true}, nil +} + +type float32Wrapper float32 + +func (w float32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *float32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float32") + } + + *w = float32Wrapper(v.Int64) + + return nil +} + +func (w float32Wrapper) Int64Value() (Int8, error) { + if w > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w) + } + + return Int8{Int64: int64(w), Valid: true}, nil +} + +func (w *float32Wrapper) ScanFloat64(v Float8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float32") + } + + *w = float32Wrapper(v.Float64) + + return nil +} + +func (w float32Wrapper) Float64Value() (Float8, error) { + return Float8{Float64: float64(w), Valid: true}, nil +} + +type float64Wrapper float64 + +func (w float64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *float64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float64") + } + + *w = float64Wrapper(v.Int64) + + return nil +} + +func (w float64Wrapper) Int64Value() (Int8, error) { + if w > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w) + } + + return Int8{Int64: int64(w), Valid: true}, nil +} + +func (w *float64Wrapper) ScanFloat64(v Float8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float64") + } + + *w = float64Wrapper(v.Float64) + + return nil +} + +func (w float64Wrapper) Float64Value() (Float8, error) { + return Float8{Float64: float64(w), Valid: true}, nil +} + +type stringWrapper string + +func (w stringWrapper) SkipUnderlyingTypePlan() {} + +func (w *stringWrapper) ScanText(v Text) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *string") + } + + *w = stringWrapper(v.String) + return nil +} + +func (w stringWrapper) TextValue() (Text, error) { + return Text{String: string(w), Valid: true}, nil +} + +type timeWrapper time.Time + +func (w *timeWrapper) ScanDate(v Date) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + switch v.InfinityModifier { + case Finite: + *w = timeWrapper(v.Time) + return nil + case Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (w timeWrapper) DateValue() (Date, error) { + return Date{Time: time.Time(w), Valid: true}, nil +} + +func (w *timeWrapper) ScanTimestamp(v Timestamp) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + switch v.InfinityModifier { + case Finite: + *w = timeWrapper(v.Time) + return nil + case Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (w timeWrapper) TimestampValue() (Timestamp, error) { + return Timestamp{Time: time.Time(w), Valid: true}, nil +} + +func (w *timeWrapper) ScanTimestamptz(v Timestamptz) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + switch v.InfinityModifier { + case Finite: + *w = timeWrapper(v.Time) + return nil + case Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (w timeWrapper) TimestamptzValue() (Timestamptz, error) { + return Timestamptz{Time: time.Time(w), Valid: true}, nil +} + +func (w *timeWrapper) ScanTime(v Time) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. + var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 + if v.Microseconds > maxRepresentableByTime { + return fmt.Errorf("%d microseconds cannot be represented as time.Time", v.Microseconds) + } + + usec := v.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + ns := usec * 1000 + *w = timeWrapper(time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC)) + return nil +} + +func (w timeWrapper) TimeValue() (Time, error) { + t := time.Time(w) + usec := int64(t.Hour())*microsecondsPerHour + + int64(t.Minute())*microsecondsPerMinute + + int64(t.Second())*microsecondsPerSecond + + int64(t.Nanosecond())/1000 + return Time{Microseconds: usec, Valid: true}, nil +} + +type durationWrapper time.Duration + +func (w durationWrapper) SkipUnderlyingTypePlan() {} + +func (w *durationWrapper) ScanInterval(v Interval) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Interval") + } + + us := int64(v.Months)*microsecondsPerMonth + int64(v.Days)*microsecondsPerDay + v.Microseconds + *w = durationWrapper(time.Duration(us) * time.Microsecond) + return nil +} + +func (w durationWrapper) IntervalValue() (Interval, error) { + return Interval{Microseconds: int64(w) / 1000, Valid: true}, nil +} + +type netIPNetWrapper net.IPNet + +func (w *netIPNetWrapper) ScanNetipPrefix(v netip.Prefix) error { + if !v.IsValid() { + return fmt.Errorf("cannot scan NULL into *net.IPNet") + } + + *w = netIPNetWrapper{ + IP: v.Addr().AsSlice(), + Mask: net.CIDRMask(v.Bits(), v.Addr().BitLen()), + } + + return nil +} +func (w netIPNetWrapper) NetipPrefixValue() (netip.Prefix, error) { + ip, ok := netip.AddrFromSlice(w.IP) + if !ok { + return netip.Prefix{}, errors.New("invalid net.IPNet") + } + + ones, _ := w.Mask.Size() + + return netip.PrefixFrom(ip, ones), nil +} + +type netIPWrapper net.IP + +func (w netIPWrapper) SkipUnderlyingTypePlan() {} + +func (w *netIPWrapper) ScanNetipPrefix(v netip.Prefix) error { + if !v.IsValid() { + *w = nil + return nil + } + + if v.Addr().BitLen() != v.Bits() { + return fmt.Errorf("cannot scan %v to *net.IP", v) + } + + *w = netIPWrapper(v.Addr().AsSlice()) + return nil +} + +func (w netIPWrapper) NetipPrefixValue() (netip.Prefix, error) { + if w == nil { + return netip.Prefix{}, nil + } + + addr, ok := netip.AddrFromSlice([]byte(w)) + if !ok { + return netip.Prefix{}, errors.New("invalid net.IP") + } + + return netip.PrefixFrom(addr, addr.BitLen()), nil +} + +type netipPrefixWrapper netip.Prefix + +func (w *netipPrefixWrapper) ScanNetipPrefix(v netip.Prefix) error { + *w = netipPrefixWrapper(v) + return nil +} + +func (w netipPrefixWrapper) NetipPrefixValue() (netip.Prefix, error) { + return netip.Prefix(w), nil +} + +type netipAddrWrapper netip.Addr + +func (w *netipAddrWrapper) ScanNetipPrefix(v netip.Prefix) error { + if !v.IsValid() { + *w = netipAddrWrapper(netip.Addr{}) + return nil + } + + if v.Addr().BitLen() != v.Bits() { + return fmt.Errorf("cannot scan %v to netip.Addr", v) + } + + *w = netipAddrWrapper(v.Addr()) + + return nil +} + +func (w netipAddrWrapper) NetipPrefixValue() (netip.Prefix, error) { + addr := (netip.Addr)(w) + if !addr.IsValid() { + return netip.Prefix{}, nil + } + + return netip.PrefixFrom(addr, addr.BitLen()), nil +} + +type mapStringToPointerStringWrapper map[string]*string + +func (w *mapStringToPointerStringWrapper) ScanHstore(v Hstore) error { + *w = mapStringToPointerStringWrapper(v) + return nil +} + +func (w mapStringToPointerStringWrapper) HstoreValue() (Hstore, error) { + return Hstore(w), nil +} + +type mapStringToStringWrapper map[string]string + +func (w *mapStringToStringWrapper) ScanHstore(v Hstore) error { + *w = make(mapStringToStringWrapper, len(v)) + for k, v := range v { + if v == nil { + return fmt.Errorf("cannot scan NULL to string") + } + (*w)[k] = *v + } + return nil +} + +func (w mapStringToStringWrapper) HstoreValue() (Hstore, error) { + if w == nil { + return nil, nil + } + + hstore := make(Hstore, len(w)) + for k, v := range w { + s := v + hstore[k] = &s + } + return hstore, nil +} + +type fmtStringerWrapper struct { + s fmt.Stringer +} + +func (w fmtStringerWrapper) TextValue() (Text, error) { + return Text{String: w.s.String(), Valid: true}, nil +} + +type byte16Wrapper [16]byte + +func (w *byte16Wrapper) ScanUUID(v UUID) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *[16]byte") + } + *w = byte16Wrapper(v.Bytes) + return nil +} + +func (w byte16Wrapper) UUIDValue() (UUID, error) { + return UUID{Bytes: [16]byte(w), Valid: true}, nil +} + +type byteSliceWrapper []byte + +func (w byteSliceWrapper) SkipUnderlyingTypePlan() {} + +func (w *byteSliceWrapper) ScanText(v Text) error { + if !v.Valid { + *w = nil + return nil + } + + *w = byteSliceWrapper(v.String) + return nil +} + +func (w byteSliceWrapper) TextValue() (Text, error) { + if w == nil { + return Text{}, nil + } + + return Text{String: string(w), Valid: true}, nil +} + +func (w *byteSliceWrapper) ScanUUID(v UUID) error { + if !v.Valid { + *w = nil + return nil + } + *w = make(byteSliceWrapper, 16) + copy(*w, v.Bytes[:]) + return nil +} + +func (w byteSliceWrapper) UUIDValue() (UUID, error) { + if w == nil { + return UUID{}, nil + } + + uuid := UUID{Valid: true} + copy(uuid.Bytes[:], w) + return uuid, nil +} + +// structWrapper implements CompositeIndexGetter for a struct. +type structWrapper struct { + s any + exportedFields []reflect.Value +} + +func (w structWrapper) IsNull() bool { + return w.s == nil +} + +func (w structWrapper) Index(i int) any { + if i >= len(w.exportedFields) { + return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i) + } + + return w.exportedFields[i].Interface() +} + +// ptrStructWrapper implements CompositeIndexScanner for a pointer to a struct. +type ptrStructWrapper struct { + s any + exportedFields []reflect.Value +} + +func (w *ptrStructWrapper) ScanNull() error { + return fmt.Errorf("cannot scan NULL into %#v", w.s) +} + +func (w *ptrStructWrapper) ScanIndex(i int) any { + if i >= len(w.exportedFields) { + return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i) + } + + return w.exportedFields[i].Addr().Interface() +} + +type anySliceArrayReflect struct { + slice reflect.Value +} + +func (a anySliceArrayReflect) Dimensions() []ArrayDimension { + if a.slice.IsNil() { + return nil + } + + return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} +} + +func (a anySliceArrayReflect) Index(i int) any { + return a.slice.Index(i).Interface() +} + +func (a anySliceArrayReflect) IndexType() any { + return reflect.New(a.slice.Type().Elem()).Elem().Interface() +} + +func (a *anySliceArrayReflect) SetDimensions(dimensions []ArrayDimension) error { + sliceType := a.slice.Type() + + if dimensions == nil { + a.slice.Set(reflect.Zero(sliceType)) + return nil + } + + elementCount := cardinality(dimensions) + slice := reflect.MakeSlice(sliceType, elementCount, elementCount) + a.slice.Set(slice) + return nil +} + +func (a *anySliceArrayReflect) ScanIndex(i int) any { + return a.slice.Index(i).Addr().Interface() +} + +func (a *anySliceArrayReflect) ScanIndexType() any { + return reflect.New(a.slice.Type().Elem()).Interface() +} + +type anyMultiDimSliceArray struct { + slice reflect.Value + dims []ArrayDimension +} + +func (a *anyMultiDimSliceArray) Dimensions() []ArrayDimension { + if a.slice.IsNil() { + return nil + } + + s := a.slice + for { + a.dims = append(a.dims, ArrayDimension{Length: int32(s.Len()), LowerBound: 1}) + if s.Len() > 0 { + s = s.Index(0) + } else { + break + } + if s.Type().Kind() == reflect.Slice { + } else { + break + } + } + + return a.dims +} + +func (a *anyMultiDimSliceArray) Index(i int) any { + if len(a.dims) == 1 { + return a.slice.Index(i).Interface() + } + + indexes := make([]int, len(a.dims)) + for j := len(a.dims) - 1; j >= 0; j-- { + dimLen := int(a.dims[j].Length) + indexes[j] = i % dimLen + i = i / dimLen + } + + v := a.slice + for _, si := range indexes { + v = v.Index(si) + } + + return v.Interface() +} + +func (a *anyMultiDimSliceArray) IndexType() any { + lowestSliceType := a.slice.Type() + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + } + return reflect.New(lowestSliceType.Elem()).Elem().Interface() +} + +func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error { + sliceType := a.slice.Type() + + if dimensions == nil { + a.slice.Set(reflect.Zero(sliceType)) + return nil + } + + switch len(dimensions) { + case 0: + // Empty, but non-nil array + slice := reflect.MakeSlice(sliceType, 0, 0) + a.slice.Set(slice) + return nil + case 1: + elementCount := cardinality(dimensions) + slice := reflect.MakeSlice(sliceType, elementCount, elementCount) + a.slice.Set(slice) + return nil + default: + sliceDimensionCount := 1 + lowestSliceType := sliceType + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + sliceDimensionCount++ + } + + if sliceDimensionCount != len(dimensions) { + return fmt.Errorf("PostgreSQL array has %d dimensions but slice has %d dimensions", len(dimensions), sliceDimensionCount) + } + + elementCount := cardinality(dimensions) + flatSlice := reflect.MakeSlice(lowestSliceType, elementCount, elementCount) + + multiDimSlice := a.makeMultidimensionalSlice(sliceType, dimensions, flatSlice, 0) + a.slice.Set(multiDimSlice) + + // Now that a.slice is a multi-dimensional slice with the underlying data pointed at flatSlice change a.slice to + // flatSlice so ScanIndex only has to handle simple one dimensional slices. + a.slice = flatSlice + + return nil + } + +} + +func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value { + if len(dimensions) == 1 { + endIdx := flatSliceIdx + int(dimensions[0].Length) + return flatSlice.Slice3(flatSliceIdx, endIdx, endIdx) + } + + sliceLen := int(dimensions[0].Length) + slice := reflect.MakeSlice(sliceType, sliceLen, sliceLen) + for i := 0; i < sliceLen; i++ { + subSlice := a.makeMultidimensionalSlice(sliceType.Elem(), dimensions[1:], flatSlice, flatSliceIdx+(i*int(dimensions[1].Length))) + slice.Index(i).Set(subSlice) + } + + return slice +} + +func (a *anyMultiDimSliceArray) ScanIndex(i int) any { + return a.slice.Index(i).Addr().Interface() +} + +func (a *anyMultiDimSliceArray) ScanIndexType() any { + lowestSliceType := a.slice.Type() + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + } + return reflect.New(lowestSliceType.Elem()).Interface() +} diff --git a/pgtype/bytea.go b/pgtype/bytea.go new file mode 100644 index 00000000..2e067672 --- /dev/null +++ b/pgtype/bytea.go @@ -0,0 +1,255 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/hex" + "fmt" +) + +type BytesScanner interface { + // ScanBytes receives a byte slice of driver memory that is only valid until the next database method call. + ScanBytes(v []byte) error +} + +type BytesValuer interface { + // BytesValue returns a byte slice of the byte data. The caller must not change the returned slice. + BytesValue() ([]byte, error) +} + +// DriverBytes is a byte slice that holds a reference to memory owned by the driver. It is only valid from the time it +// is scanned until Rows.Next or Rows.Close is called. It is never safe to use DriverBytes with QueryRow as Row.Scan +// internally calls Rows.Close before returning. +type DriverBytes []byte + +func (b *DriverBytes) ScanBytes(v []byte) error { + *b = v + return nil +} + +// PreallocBytes is a byte slice of preallocated memory that scanned bytes will be copied to. If it is too small a new +// slice will be allocated. +type PreallocBytes []byte + +func (b *PreallocBytes) ScanBytes(v []byte) error { + if v == nil { + *b = nil + return nil + } + + if len(v) <= len(*b) { + *b = (*b)[:len(v)] + } else { + *b = make(PreallocBytes, len(v)) + } + copy(*b, v) + return nil +} + +// UndecodedBytes can be used as a scan target to get the raw bytes from PostgreSQL without any decoding. +type UndecodedBytes []byte + +type scanPlanAnyToUndecodedBytes struct{} + +func (scanPlanAnyToUndecodedBytes) Scan(src []byte, dst any) error { + dstBuf := dst.(*UndecodedBytes) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type ByteaCodec struct{} + +func (ByteaCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (ByteaCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (ByteaCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case []byte: + return encodePlanBytesCodecBinaryBytes{} + case BytesValuer: + return encodePlanBytesCodecBinaryBytesValuer{} + } + case TextFormatCode: + switch value.(type) { + case []byte: + return encodePlanBytesCodecTextBytes{} + case BytesValuer: + return encodePlanBytesCodecTextBytesValuer{} + } + } + + return nil +} + +type encodePlanBytesCodecBinaryBytes struct{} + +func (encodePlanBytesCodecBinaryBytes) Encode(value any, buf []byte) (newBuf []byte, err error) { + b := value.([]byte) + if b == nil { + return nil, nil + } + + return append(buf, b...), nil +} + +type encodePlanBytesCodecBinaryBytesValuer struct{} + +func (encodePlanBytesCodecBinaryBytesValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + b, err := value.(BytesValuer).BytesValue() + if err != nil { + return nil, err + } + if b == nil { + return nil, nil + } + + return append(buf, b...), nil +} + +type encodePlanBytesCodecTextBytes struct{} + +func (encodePlanBytesCodecTextBytes) Encode(value any, buf []byte) (newBuf []byte, err error) { + b := value.([]byte) + if b == nil { + return nil, nil + } + + buf = append(buf, `\x`...) + buf = append(buf, hex.EncodeToString(b)...) + return buf, nil +} + +type encodePlanBytesCodecTextBytesValuer struct{} + +func (encodePlanBytesCodecTextBytesValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + b, err := value.(BytesValuer).BytesValue() + if err != nil { + return nil, err + } + if b == nil { + return nil, nil + } + + buf = append(buf, `\x`...) + buf = append(buf, hex.EncodeToString(b)...) + return buf, nil +} + +func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *[]byte: + return scanPlanBinaryBytesToBytes{} + case BytesScanner: + return scanPlanBinaryBytesToBytesScanner{} + } + case TextFormatCode: + switch target.(type) { + case *[]byte: + return scanPlanTextByteaToBytes{} + case BytesScanner: + return scanPlanTextByteaToBytesScanner{} + } + } + + return nil +} + +type scanPlanBinaryBytesToBytes struct{} + +func (scanPlanBinaryBytesToBytes) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanBinaryBytesToBytesScanner struct{} + +func (scanPlanBinaryBytesToBytesScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BytesScanner) + return scanner.ScanBytes(src) +} + +type scanPlanTextByteaToBytes struct{} + +func (scanPlanTextByteaToBytes) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + buf, err := decodeHexBytea(src) + if err != nil { + return err + } + *dstBuf = buf + + return nil +} + +type scanPlanTextByteaToBytesScanner struct{} + +func (scanPlanTextByteaToBytesScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BytesScanner) + buf, err := decodeHexBytea(src) + if err != nil { + return err + } + return scanner.ScanBytes(buf) +} + +func decodeHexBytea(src []byte) ([]byte, error) { + if src == nil { + return nil, nil + } + + if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { + return nil, fmt.Errorf("invalid hex format") + } + + buf := make([]byte, (len(src)-2)/2) + _, err := hex.Decode(buf, src[2:]) + if err != nil { + return nil, err + } + + return buf, nil +} + +func (c ByteaCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c ByteaCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var buf []byte + err := codecScan(c, m, oid, format, src, &buf) + if err != nil { + return nil, err + } + return buf, nil +} diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go new file mode 100644 index 00000000..a0d27369 --- /dev/null +++ b/pgtype/bytea_test.go @@ -0,0 +1,116 @@ +package pgtype_test + +import ( + "bytes" + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +func isExpectedEqBytes(a any) func(any) bool { + return func(v any) bool { + ab := a.([]byte) + vb := v.([]byte) + + if (ab == nil) != (vb == nil) { + return false + } + + if ab == nil { + return true + } + + return bytes.Compare(ab, vb) == 0 + } +} + +func TestByteaCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bytea", []pgxtest.ValueRoundTripTest{ + {[]byte{1, 2, 3}, new([]byte), isExpectedEqBytes([]byte{1, 2, 3})}, + {[]byte{}, new([]byte), isExpectedEqBytes([]byte{})}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + }) +} + +func TestDriverBytesQueryRow(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var buf []byte + err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf)) + require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow") + }) +} + +func TestDriverBytes(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + argBuf := make([]byte, 128) + for i := range argBuf { + argBuf[i] = byte(i) + } + + rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf) + require.NoError(t, err) + defer rows.Close() + + rowCount := 0 + resultBuf := argBuf + detectedResultMutation := false + for rows.Next() { + rowCount++ + + // At some point the buffer should be reused and change. + if bytes.Compare(argBuf, resultBuf) != 0 { + detectedResultMutation = true + } + + err = rows.Scan((*pgtype.DriverBytes)(&resultBuf)) + require.NoError(t, err) + + require.Len(t, resultBuf, len(argBuf)) + require.Equal(t, resultBuf, argBuf) + require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)") + } + + require.True(t, detectedResultMutation) + + err = rows.Err() + require.NoError(t, err) + }) +} + +func TestPreallocBytes(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + origBuf := []byte{5, 6, 7, 8} + buf := origBuf + err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.PreallocBytes)(&buf)) + require.NoError(t, err) + + require.Len(t, buf, 2) + require.Equal(t, 4, cap(buf)) + require.Equal(t, buf, []byte{1, 2}) + + require.Equal(t, []byte{1, 2, 7, 8}, origBuf) + + err = conn.QueryRow(ctx, `select $1::bytea`, []byte{3, 4, 5, 6, 7}).Scan((*pgtype.PreallocBytes)(&buf)) + require.NoError(t, err) + require.Len(t, buf, 5) + require.Equal(t, 5, cap(buf)) + + require.Equal(t, []byte{1, 2, 7, 8}, origBuf) + }) +} + +func TestUndecodedBytes(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var buf []byte + err := conn.QueryRow(ctx, `select 1::int4`).Scan((*pgtype.UndecodedBytes)(&buf)) + require.NoError(t, err) + + require.Len(t, buf, 4) + require.Equal(t, buf, []byte{0, 0, 0, 1}) + }) +} diff --git a/pgtype/circle.go b/pgtype/circle.go new file mode 100644 index 00000000..e8f118cc --- /dev/null +++ b/pgtype/circle.go @@ -0,0 +1,222 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type CircleScanner interface { + ScanCircle(v Circle) error +} + +type CircleValuer interface { + CircleValue() (Circle, error) +} + +type Circle struct { + P Vec2 + R float64 + Valid bool +} + +func (c *Circle) ScanCircle(v Circle) error { + *c = v + return nil +} + +func (c Circle) CircleValue() (Circle, error) { + return c, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Circle) Scan(src any) error { + if src == nil { + *dst = Circle{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToCircleScanner{}.Scan([]byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Circle) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := CircleCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type CircleCodec struct{} + +func (CircleCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (CircleCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (CircleCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(CircleValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanCircleCodecBinary{} + case TextFormatCode: + return encodePlanCircleCodecText{} + } + + return nil +} + +type encodePlanCircleCodecBinary struct{} + +func (encodePlanCircleCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + circle, err := value.(CircleValuer).CircleValue() + if err != nil { + return nil, err + } + + if !circle.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.R)) + return buf, nil +} + +type encodePlanCircleCodecText struct{} + +func (encodePlanCircleCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + circle, err := value.(CircleValuer).CircleValue() + if err != nil { + return nil, err + } + + if !circle.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, + strconv.FormatFloat(circle.P.X, 'f', -1, 64), + strconv.FormatFloat(circle.P.Y, 'f', -1, 64), + strconv.FormatFloat(circle.R, 'f', -1, 64), + )...) + return buf, nil +} + +func (CircleCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case CircleScanner: + return scanPlanBinaryCircleToCircleScanner{} + } + case TextFormatCode: + switch target.(type) { + case CircleScanner: + return scanPlanTextAnyToCircleScanner{} + } + } + + return nil +} + +func (c CircleCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c CircleCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var circle Circle + err := codecScan(c, m, oid, format, src, &circle) + if err != nil { + return nil, err + } + return circle, nil +} + +type scanPlanBinaryCircleToCircleScanner struct{} + +func (scanPlanBinaryCircleToCircleScanner) Scan(src []byte, dst any) error { + scanner := (dst).(CircleScanner) + + if src == nil { + return scanner.ScanCircle(Circle{}) + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + r := binary.BigEndian.Uint64(src[16:]) + + return scanner.ScanCircle(Circle{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + R: math.Float64frombits(r), + Valid: true, + }) +} + +type scanPlanTextAnyToCircleScanner struct{} + +func (scanPlanTextAnyToCircleScanner) Scan(src []byte, dst any) error { + scanner := (dst).(CircleScanner) + + if src == nil { + return scanner.ScanCircle(Circle{}) + } + + if len(src) < 9 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + str := string(src[2:]) + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+2 : len(str)-1] + + r, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + return scanner.ScanCircle(Circle{P: Vec2{x, y}, R: r, Valid: true}) +} diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go new file mode 100644 index 00000000..7b6db777 --- /dev/null +++ b/pgtype/circle_test.go @@ -0,0 +1,28 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestCircleTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support box type") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "circle", []pgxtest.ValueRoundTripTest{ + { + pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, + new(pgtype.Circle), + isExpectedEq(pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}), + }, + { + pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, + new(pgtype.Circle), + isExpectedEq(pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}), + }, + {pgtype.Circle{}, new(pgtype.Circle), isExpectedEq(pgtype.Circle{})}, + {nil, new(pgtype.Circle), isExpectedEq(pgtype.Circle{})}, + }) +} diff --git a/pgtype/composite.go b/pgtype/composite.go new file mode 100644 index 00000000..fb372325 --- /dev/null +++ b/pgtype/composite.go @@ -0,0 +1,602 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// CompositeIndexGetter is a type accessed by index that can be converted into a PostgreSQL composite. +type CompositeIndexGetter interface { + // IsNull returns true if the value is SQL NULL. + IsNull() bool + + // Index returns the element at i. + Index(i int) any +} + +// CompositeIndexScanner is a type accessed by index that can be scanned from a PostgreSQL composite. +type CompositeIndexScanner interface { + // ScanNull sets the value to SQL NULL. + ScanNull() error + + // ScanIndex returns a value usable as a scan target for i. + ScanIndex(i int) any +} + +type CompositeCodecField struct { + Name string + Type *Type +} + +type CompositeCodec struct { + Fields []CompositeCodecField +} + +func (c *CompositeCodec) FormatSupported(format int16) bool { + for _, f := range c.Fields { + if !f.Type.Codec.FormatSupported(format) { + return false + } + } + + return true +} + +func (c *CompositeCodec) PreferredFormat() int16 { + if c.FormatSupported(BinaryFormatCode) { + return BinaryFormatCode + } + return TextFormatCode +} + +func (c *CompositeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(CompositeIndexGetter); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, m: m} + case TextFormatCode: + return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, m: m} + } + + return nil +} + +type encodePlanCompositeCodecCompositeIndexGetterToBinary struct { + cc *CompositeCodec + m *Map +} + +func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + getter := value.(CompositeIndexGetter) + + if getter.IsNull() { + return nil, nil + } + + builder := NewCompositeBinaryBuilder(plan.m, buf) + for i, field := range plan.cc.Fields { + builder.AppendValue(field.Type.OID, getter.Index(i)) + } + + return builder.Finish() +} + +type encodePlanCompositeCodecCompositeIndexGetterToText struct { + cc *CompositeCodec + m *Map +} + +func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value any, buf []byte) (newBuf []byte, err error) { + getter := value.(CompositeIndexGetter) + + if getter.IsNull() { + return nil, nil + } + + b := NewCompositeTextBuilder(plan.m, buf) + for i, field := range plan.cc.Fields { + b.AppendValue(field.Type.OID, getter.Index(i)) + } + + return b.Finish() +} + +func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case CompositeIndexScanner: + return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, m: m} + } + case TextFormatCode: + switch target.(type) { + case CompositeIndexScanner: + return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, m: m} + } + } + + return nil +} + +type scanPlanBinaryCompositeToCompositeIndexScanner struct { + cc *CompositeCodec + m *Map +} + +func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target any) error { + targetScanner := (target).(CompositeIndexScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeBinaryScanner(plan.m, src) + for i, field := range plan.cc.Fields { + if scanner.Next() { + fieldTarget := targetScanner.ScanIndex(i) + if fieldTarget != nil { + fieldPlan := plan.m.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.Type.OID) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } else { + return errors.New("read past end of composite") + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + +type scanPlanTextCompositeToCompositeIndexScanner struct { + cc *CompositeCodec + m *Map +} + +func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target any) error { + targetScanner := (target).(CompositeIndexScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeTextScanner(plan.m, src) + for i, field := range plan.cc.Fields { + if scanner.Next() { + fieldTarget := targetScanner.ScanIndex(i) + if fieldTarget != nil { + fieldPlan := plan.m.PlanScan(field.Type.OID, TextFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.Type.OID) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } else { + return errors.New("read past end of composite") + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + +func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + scanner := NewCompositeTextScanner(m, src) + values := make(map[string]any, len(c.Fields)) + for i := 0; scanner.Next() && i < len(c.Fields); i++ { + var v any + fieldPlan := m.PlanScan(c.Fields[i].Type.OID, TextFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].Type.OID, v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[c.Fields[i].Name] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + case BinaryFormatCode: + scanner := NewCompositeBinaryScanner(m, src) + values := make(map[string]any, len(c.Fields)) + for i := 0; scanner.Next() && i < len(c.Fields); i++ { + var v any + fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[c.Fields[i].Name] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } + +} + +type CompositeBinaryScanner struct { + m *Map + rp int + src []byte + + fieldCount int32 + fieldBytes []byte + fieldOID uint32 + err error +} + +// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. +func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner { + rp := 0 + if len(src[rp:]) < 4 { + return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)} + } + + fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + return &CompositeBinaryScanner{ + m: m, + rp: rp, + src: src, + fieldCount: fieldCount, + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeBinaryScanner) Next() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + if len(cfs.src[cfs.rp:]) < 8 { + cfs.err = fmt.Errorf("Record incomplete %v", cfs.src) + return false + } + cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:]) + cfs.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:]))) + cfs.rp += 4 + + if fieldLen >= 0 { + if len(cfs.src[cfs.rp:]) < fieldLen { + cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src) + return false + } + cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen] + cfs.rp += fieldLen + } else { + cfs.fieldBytes = nil + } + + return true +} + +func (cfs *CompositeBinaryScanner) FieldCount() int { + return int(cfs.fieldCount) +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// OID returns the OID of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) OID() uint32 { + return cfs.fieldOID +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeBinaryScanner) Err() error { + return cfs.err +} + +type CompositeTextScanner struct { + m *Map + rp int + src []byte + + fieldBytes []byte + err error +} + +// NewCompositeTextScanner a scanner over a text encoded composite value. +func NewCompositeTextScanner(m *Map, src []byte) *CompositeTextScanner { + if len(src) < 2 { + return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)} + } + + if src[0] != '(' { + return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")} + } + + if src[len(src)-1] != ')' { + return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")} + } + + return &CompositeTextScanner{ + m: m, + rp: 1, + src: src, + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeTextScanner) Next() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + switch cfs.src[cfs.rp] { + case ',', ')': // null + cfs.rp++ + cfs.fieldBytes = nil + return true + case '"': // quoted value + cfs.rp++ + cfs.fieldBytes = make([]byte, 0, 16) + for { + ch := cfs.src[cfs.rp] + + if ch == '"' { + cfs.rp++ + if cfs.src[cfs.rp] == '"' { + cfs.fieldBytes = append(cfs.fieldBytes, '"') + cfs.rp++ + } else { + break + } + } else if ch == '\\' { + cfs.rp++ + cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp]) + cfs.rp++ + } else { + cfs.fieldBytes = append(cfs.fieldBytes, ch) + cfs.rp++ + } + } + cfs.rp++ + return true + default: // unquoted value + start := cfs.rp + for { + ch := cfs.src[cfs.rp] + if ch == ',' || ch == ')' { + break + } + cfs.rp++ + } + cfs.fieldBytes = cfs.src[start:cfs.rp] + cfs.rp++ + return true + } +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeTextScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeTextScanner) Err() error { + return cfs.err +} + +type CompositeBinaryBuilder struct { + m *Map + buf []byte + startIdx int + fieldCount uint32 + err error +} + +func NewCompositeBinaryBuilder(m *Map, buf []byte) *CompositeBinaryBuilder { + startIdx := len(buf) + buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields + return &CompositeBinaryBuilder{m: m, buf: buf, startIdx: startIdx} +} + +func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) { + if b.err != nil { + return + } + + if field == nil { + b.buf = pgio.AppendUint32(b.buf, oid) + b.buf = pgio.AppendInt32(b.buf, -1) + b.fieldCount++ + return + } + + plan := b.m.PlanEncode(oid, BinaryFormatCode, field) + if plan == nil { + b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid) + return + } + + b.buf = pgio.AppendUint32(b.buf, oid) + lengthPos := len(b.buf) + b.buf = pgio.AppendInt32(b.buf, -1) + fieldBuf, err := plan.Encode(field, b.buf) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) + b.buf = fieldBuf + } + + b.fieldCount++ +} + +func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount) + return b.buf, nil +} + +type CompositeTextBuilder struct { + m *Map + buf []byte + startIdx int + fieldCount uint32 + err error + fieldBuf [32]byte +} + +func NewCompositeTextBuilder(m *Map, buf []byte) *CompositeTextBuilder { + buf = append(buf, '(') // allocate room for number of fields + return &CompositeTextBuilder{m: m, buf: buf} +} + +func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) { + if b.err != nil { + return + } + + if field == nil { + b.buf = append(b.buf, ',') + return + } + + plan := b.m.PlanEncode(oid, TextFormatCode, field) + if plan == nil { + b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid) + return + } + + fieldBuf, err := plan.Encode(field, b.fieldBuf[0:0]) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...) + } + + b.buf = append(b.buf, ',') +} + +func (b *CompositeTextBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + b.buf[len(b.buf)-1] = ')' + return b.buf, nil +} + +var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteCompositeField(src string) string { + return `"` + quoteCompositeReplacer.Replace(src) + `"` +} + +func quoteCompositeFieldIfNeeded(src string) string { + if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) { + return quoteCompositeField(src) + } + return src +} + +// CompositeFields represents the values of a composite value. It can be used as an encoding source or as a scan target. +// It cannot scan a NULL, but the composite fields can be NULL. +type CompositeFields []any + +func (cf CompositeFields) SkipUnderlyingTypePlan() {} + +func (cf CompositeFields) IsNull() bool { + return cf == nil +} + +func (cf CompositeFields) Index(i int) any { + return cf[i] +} + +func (cf CompositeFields) ScanNull() error { + return fmt.Errorf("cannot scan NULL into CompositeFields") +} + +func (cf CompositeFields) ScanIndex(i int) any { + return cf[i] +} diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go new file mode 100644 index 00000000..a6fa8315 --- /dev/null +++ b/pgtype/composite_test.go @@ -0,0 +1,210 @@ +package pgtype_test + +import ( + "context" + "fmt" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" +) + +func TestCompositeCodecTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + + _, err := conn.Exec(ctx, `drop type if exists ct_test; + +create type ct_test as ( + a text, + b int4 +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type ct_test") + + dt, err := conn.LoadType(ctx, "ct_test") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + var a string + var b int32 + + err := conn.QueryRow(ctx, "select $1::ct_test", pgx.QueryResultFormats{format.code}, + pgtype.CompositeFields{"hi", int32(42)}, + ).Scan( + pgtype.CompositeFields{&a, &b}, + ) + require.NoErrorf(t, err, "%v", format.name) + require.EqualValuesf(t, "hi", a, "%v", format.name) + require.EqualValuesf(t, 42, b, "%v", format.name) + } + }) +} + +type point3d struct { + X, Y, Z float64 +} + +func (p point3d) IsNull() bool { + return false +} + +func (p point3d) Index(i int) any { + switch i { + case 0: + return p.X + case 1: + return p.Y + case 2: + return p.Z + default: + panic("invalid index") + } +} + +func (p *point3d) ScanNull() error { + return fmt.Errorf("cannot scan NULL into point3d") +} + +func (p *point3d) ScanIndex(i int) any { + switch i { + case 0: + return &p.X + case 1: + return &p.Y + case 2: + return &p.Z + default: + panic("invalid index") + } +} + +func TestCompositeCodecTranscodeStruct(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + + _, err := conn.Exec(ctx, `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type point3d") + + dt, err := conn.LoadType(ctx, "point3d") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + input := point3d{X: 1, Y: 2, Z: 3} + var output point3d + err := conn.QueryRow(ctx, "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) + require.NoErrorf(t, err, "%v", format.name) + require.Equalf(t, input, output, "%v", format.name) + } + }) +} + +func TestCompositeCodecTranscodeStructWrapper(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + + _, err := conn.Exec(ctx, `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type point3d") + + dt, err := conn.LoadType(ctx, "point3d") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + type anotherPoint struct { + X, Y, Z float64 + } + + for _, format := range formats { + input := anotherPoint{X: 1, Y: 2, Z: 3} + var output anotherPoint + err := conn.QueryRow(ctx, "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) + require.NoErrorf(t, err, "%v", format.name) + require.Equalf(t, input, output, "%v", format.name) + } + }) +} + +func TestCompositeCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + + _, err := conn.Exec(ctx, `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type point3d") + + dt, err := conn.LoadType(ctx, "point3d") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + rows, err := conn.Query(ctx, "select '(1,2,3)'::point3d", pgx.QueryResultFormats{format.code}) + require.NoErrorf(t, err, "%v", format.name) + require.True(t, rows.Next()) + values, err := rows.Values() + require.NoErrorf(t, err, "%v", format.name) + require.Lenf(t, values, 1, "%v", format.name) + require.Equalf(t, map[string]any{"x": 1.0, "y": 2.0, "z": 3.0}, values[0], "%v", format.name) + require.False(t, rows.Next()) + require.NoErrorf(t, rows.Err(), "%v", format.name) + } + }) +} diff --git a/pgtype/convert.go b/pgtype/convert.go new file mode 100644 index 00000000..8a2afbe1 --- /dev/null +++ b/pgtype/convert.go @@ -0,0 +1,476 @@ +package pgtype + +import ( + "database/sql" + "fmt" + "math" + "reflect" + "time" +) + +const ( + maxUint = ^uint(0) + maxInt = int(maxUint >> 1) + minInt = -maxInt - 1 +) + +// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 +func underlyingNumberType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Int: + convVal := int(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int8: + convVal := int8(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int16: + convVal := int16(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int32: + convVal := int32(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int64: + convVal := int64(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint: + convVal := uint(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint8: + convVal := uint8(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint16: + convVal := uint16(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint32: + convVal := uint32(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint64: + convVal := uint64(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float32: + convVal := float32(refVal.Float()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float64: + convVal := refVal.Float() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBoolType gets the underlying type that can be converted to Bool +func underlyingBoolType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Bool: + convVal := refVal.Bool() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBytesType gets the underlying type that can be converted to []byte +func underlyingBytesType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + if refVal.Type().Elem().Kind() == reflect.Uint8 { + convVal := refVal.Bytes() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + } + + return nil, false +} + +// underlyingStringType gets the underlying type that can be converted to String +func underlyingStringType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingPtrType dereferences a pointer +func underlyingPtrType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + return nil, false +} + +// underlyingTimeType gets the underlying type that can be converted to time.Time +func underlyingTimeType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + timeType := reflect.TypeOf(time.Time{}) + if refVal.Type().ConvertibleTo(timeType) { + return refVal.Convert(timeType).Interface(), true + } + + return nil, false +} + +// underlyingUUIDType gets the underlying type that can be converted to [16]byte +func underlyingUUIDType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + uuidType := reflect.TypeOf([16]byte{}) + if refVal.Type().ConvertibleTo(uuidType) { + return refVal.Convert(uuidType).Interface(), true + } + + return nil, false +} + +// underlyingSliceType gets the underlying slice type +func underlyingSliceType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + baseSliceType := reflect.SliceOf(refVal.Type().Elem()) + if refVal.Type().ConvertibleTo(baseSliceType) { + convVal := refVal.Convert(baseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + } + + return nil, false +} + +func int64AssignTo(srcVal int64, srcValid bool, dst any) error { + if srcValid { + switch v := dst.(type) { + case *int: + if srcVal < int64(minInt) { + return fmt.Errorf("%d is less than minimum value for int", srcVal) + } else if srcVal > int64(maxInt) { + return fmt.Errorf("%d is greater than maximum value for int", srcVal) + } + *v = int(srcVal) + case *int8: + if srcVal < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", srcVal) + } else if srcVal > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", srcVal) + } + *v = int8(srcVal) + case *int16: + if srcVal < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", srcVal) + } else if srcVal > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", srcVal) + } + *v = int16(srcVal) + case *int32: + if srcVal < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", srcVal) + } else if srcVal > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", srcVal) + } + *v = int32(srcVal) + case *int64: + if srcVal < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for int64", srcVal) + } else if srcVal > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for int64", srcVal) + } + *v = int64(srcVal) + case *uint: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint", srcVal) + } else if uint64(srcVal) > uint64(maxUint) { + return fmt.Errorf("%d is greater than maximum value for uint", srcVal) + } + *v = uint(srcVal) + case *uint8: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint8", srcVal) + } else if srcVal > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) + } + *v = uint8(srcVal) + case *uint16: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) + } + *v = uint16(srcVal) + case *uint32: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) + } + *v = uint32(srcVal) + case *uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint64", srcVal) + } + *v = uint64(srcVal) + case sql.Scanner: + return v.Scan(srcVal) + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return int64AssignTo(srcVal, srcValid, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if el.OverflowInt(int64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetInt(int64(srcVal)) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for %T", srcVal, dst) + } + if el.OverflowUint(uint64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetUint(uint64(srcVal)) + return nil + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Valid, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) +} + +func float64AssignTo(srcVal float64, srcValid bool, dst any) error { + if srcValid { + switch v := dst.(type) { + case *float32: + *v = float32(srcVal) + case *float64: + *v = srcVal + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a type alias of a float32 or 64, set dst val + case reflect.Float32, reflect.Float64: + el.SetFloat(srcVal) + return nil + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return float64AssignTo(srcVal, srcValid, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i64 := int64(srcVal) + if float64(i64) == srcVal { + return int64AssignTo(i64, srcValid, dst) + } + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Valid, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) +} + +func NullAssignTo(dst any) error { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return &nullAssignmentError{dst: dst} + } + + dstVal := dstPtr.Elem() + + switch dstVal.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + dstVal.Set(reflect.Zero(dstVal.Type())) + return nil + } + + return &nullAssignmentError{dst: dst} +} + +var kindTypes map[reflect.Kind]reflect.Type + +func toInterface(dst reflect.Value, t reflect.Type) (any, bool) { + nextDst := dst.Convert(t) + return nextDst.Interface(), dst.Type() != nextDst.Type() +} + +// GetAssignToDstType attempts to convert dst to something AssignTo can assign +// to. If dst is a pointer to pointer it allocates a value and returns the +// dereferences pointer. If dst is a named type such as *Foo where Foo is type +// Foo int16, it converts dst to *int16. +// +// GetAssignToDstType returns the converted dst and a bool representing if any +// change was made. +func GetAssignToDstType(dst any) (any, bool) { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return nil, false + } + + dstVal := dstPtr.Elem() + + // if dst is a pointer to pointer, allocate space try again with the dereferenced pointer + if dstVal.Kind() == reflect.Ptr { + dstVal.Set(reflect.New(dstVal.Type().Elem())) + return dstVal.Interface(), true + } + + // if dst is pointer to a base type that has been renamed + if baseValType, ok := kindTypes[dstVal.Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(baseValType)) + } + + if dstVal.Kind() == reflect.Slice { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(reflect.SliceOf(baseElemType))) + } + } + + if dstVal.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType))) + } + } + + if dstVal.Kind() == reflect.Struct { + if dstVal.Type().NumField() == 1 && dstVal.Type().Field(0).Anonymous { + dstPtr = dstVal.Field(0).Addr() + nested := dstVal.Type().Field(0).Type + if nested.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[nested.Elem().Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType))) + } + } + if _, ok := kindTypes[nested.Kind()]; ok && dstPtr.CanInterface() { + return dstPtr.Interface(), true + } + } + } + + return nil, false +} + +func init() { + kindTypes = map[reflect.Kind]reflect.Type{ + reflect.Bool: reflect.TypeOf(false), + reflect.Float32: reflect.TypeOf(float32(0)), + reflect.Float64: reflect.TypeOf(float64(0)), + reflect.Int: reflect.TypeOf(int(0)), + reflect.Int8: reflect.TypeOf(int8(0)), + reflect.Int16: reflect.TypeOf(int16(0)), + reflect.Int32: reflect.TypeOf(int32(0)), + reflect.Int64: reflect.TypeOf(int64(0)), + reflect.Uint: reflect.TypeOf(uint(0)), + reflect.Uint8: reflect.TypeOf(uint8(0)), + reflect.Uint16: reflect.TypeOf(uint16(0)), + reflect.Uint32: reflect.TypeOf(uint32(0)), + reflect.Uint64: reflect.TypeOf(uint64(0)), + reflect.String: reflect.TypeOf(""), + } +} diff --git a/pgtype/date.go b/pgtype/date.go new file mode 100644 index 00000000..78c5db92 --- /dev/null +++ b/pgtype/date.go @@ -0,0 +1,327 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type DateScanner interface { + ScanDate(v Date) error +} + +type DateValuer interface { + DateValue() (Date, error) +} + +type Date struct { + Time time.Time + InfinityModifier InfinityModifier + Valid bool +} + +func (d *Date) ScanDate(v Date) error { + *d = v + return nil +} + +func (d Date) DateValue() (Date, error) { + return d, nil +} + +const ( + negativeInfinityDayOffset = -2147483648 + infinityDayOffset = 2147483647 +) + +// Scan implements the database/sql Scanner interface. +func (dst *Date) Scan(src any) error { + if src == nil { + *dst = Date{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToDateScanner{}.Scan([]byte(src), dst) + case time.Time: + *dst = Date{Time: src, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Date) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + if src.InfinityModifier != Finite { + return src.InfinityModifier.String(), nil + } + return src.Time, nil +} + +func (src Date) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + var s string + + switch src.InfinityModifier { + case Finite: + s = src.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (dst *Date) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Date{} + return nil + } + + switch *s { + case "infinity": + *dst = Date{Valid: true, InfinityModifier: Infinity} + case "-infinity": + *dst = Date{Valid: true, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", *s, time.UTC) + if err != nil { + return err + } + + *dst = Date{Time: t, Valid: true} + } + + return nil +} + +type DateCodec struct{} + +func (DateCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (DateCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (DateCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(DateValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanDateCodecBinary{} + case TextFormatCode: + return encodePlanDateCodecText{} + } + + return nil +} + +type encodePlanDateCodecBinary struct{} + +func (encodePlanDateCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + date, err := value.(DateValuer).DateValue() + if err != nil { + return nil, err + } + + if !date.Valid { + return nil, nil + } + + var daysSinceDateEpoch int32 + switch date.InfinityModifier { + case Finite: + tUnix := time.Date(date.Time.Year(), date.Time.Month(), date.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch = int32(secSinceDateEpoch / 86400) + case Infinity: + daysSinceDateEpoch = infinityDayOffset + case NegativeInfinity: + daysSinceDateEpoch = negativeInfinityDayOffset + } + + return pgio.AppendInt32(buf, daysSinceDateEpoch), nil +} + +type encodePlanDateCodecText struct{} + +func (encodePlanDateCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + date, err := value.(DateValuer).DateValue() + if err != nil { + return nil, err + } + + if !date.Valid { + return nil, nil + } + + switch date.InfinityModifier { + case Finite: + // Year 0000 is 1 BC + bc := false + year := date.Time.Year() + if year <= 0 { + year = -year + 1 + bc = true + } + + buf = strconv.AppendInt(buf, int64(year), 10) + buf = append(buf, '-') + buf = strconv.AppendInt(buf, int64(date.Time.Month()), 10) + buf = append(buf, '-') + buf = strconv.AppendInt(buf, int64(date.Time.Day()), 10) + + if bc { + buf = append(buf, " BC"...) + } + case Infinity: + buf = append(buf, "infinity"...) + case NegativeInfinity: + buf = append(buf, "-infinity"...) + } + + return buf, nil +} + +func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case DateScanner: + return scanPlanBinaryDateToDateScanner{} + } + case TextFormatCode: + switch target.(type) { + case DateScanner: + return scanPlanTextAnyToDateScanner{} + } + } + + return nil +} + +type scanPlanBinaryDateToDateScanner struct{} + +func (scanPlanBinaryDateToDateScanner) Scan(src []byte, dst any) error { + scanner := (dst).(DateScanner) + + if src == nil { + return scanner.ScanDate(Date{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for date: %v", len(src)) + } + + dayOffset := int32(binary.BigEndian.Uint32(src)) + + switch dayOffset { + case infinityDayOffset: + return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true}) + case negativeInfinityDayOffset: + return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true}) + default: + t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) + return scanner.ScanDate(Date{Time: t, Valid: true}) + } +} + +type scanPlanTextAnyToDateScanner struct{} + +func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst any) error { + scanner := (dst).(DateScanner) + + if src == nil { + return scanner.ScanDate(Date{}) + } + + sbuf := string(src) + switch sbuf { + case "infinity": + return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true}) + case "-infinity": + return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true}) + default: + if len(sbuf) >= 10 { + year, err := strconv.ParseInt(sbuf[0:4], 10, 32) + if err != nil { + return fmt.Errorf("cannot parse year: %v", err) + } + month, err := strconv.ParseInt(sbuf[5:7], 10, 32) + if err != nil { + return fmt.Errorf("cannot parse month: %v", err) + } + day, err := strconv.ParseInt(sbuf[8:10], 10, 32) + if err != nil { + return fmt.Errorf("cannot parse day: %v", err) + } + + if len(sbuf) == 13 && sbuf[11:] == "BC" { + year = -year + 1 + } + + t := time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC) + return scanner.ScanDate(Date{Time: t, Valid: true}) + } else { + return fmt.Errorf("date too short") + } + } +} + +func (c DateCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c DateCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var date Date + err := codecScan(c, m, oid, format, src, &date) + if err != nil { + return nil, err + } + + if date.Valid { + switch date.InfinityModifier { + case Finite: + return date.Time, nil + case Infinity: + return "infinity", nil + case NegativeInfinity: + return "-infinity", nil + } + } + + return nil, nil +} diff --git a/pgtype/date_test.go b/pgtype/date_test.go new file mode 100644 index 00000000..de61fd72 --- /dev/null +++ b/pgtype/date_test.go @@ -0,0 +1,87 @@ +package pgtype_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqTime(a any) func(any) bool { + return func(v any) bool { + at := a.(time.Time) + vt := v.(time.Time) + + return at.Equal(vt) + } +} + +func TestDateCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "date", []pgxtest.ValueRoundTripTest{ + {time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC))}, + {time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC))}, + {pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Date), isExpectedEq(pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Date), isExpectedEq(pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Date{}, new(pgtype.Date), isExpectedEq(pgtype.Date{})}, + {nil, new(*time.Time), isExpectedEq((*time.Time)(nil))}, + }) +} + +func TestDateMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Date + result string + }{ + {source: pgtype.Date{}, result: "null"}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""}, + {source: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestDateUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Date + }{ + {source: "null", result: pgtype.Date{}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"infinity\"", result: pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: "\"-infinity\"", result: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Date + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r.Time.Year() != tt.result.Time.Year() || r.Time.Month() != tt.result.Time.Month() || r.Time.Day() != tt.result.Time.Day() || r.Valid != tt.result.Valid || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/doc.go b/pgtype/doc.go new file mode 100644 index 00000000..834aa0b6 --- /dev/null +++ b/pgtype/doc.go @@ -0,0 +1,151 @@ +// Package pgtype converts between Go and PostgreSQL values. +/* +The primary type is the Map type. It is a map of PostgreSQL types identified by OID (object ID) to a Codec. A Codec is +responsible for converting between Go and PostgreSQL values. NewMap creates a Map with all supported standard PostgreSQL +types already registered. Additional types can be registered with Map.RegisterType. + +Use Map.Scan and Map.Encode to decode PostgreSQL values to Go and encode Go values to PostgreSQL respectively. + +Base Type Mapping + +pgtype maps between all common base types directly between Go and PostgreSQL. In particular: + + Go PostgreSQL + ----------------------- + string varchar + text + + // Integers are automatically be converted to any other integer type if + // it can be done without overflow or underflow. + int8 + int16 smallint + int32 int + int64 bigint + int + uint8 + uint16 + uint32 + uint64 + uint + + // Floats are strict and do not automatically convert like integers. + float32 float4 + float64 float8 + + time.Time date + timestamp + timestamptz + + netip.Addr inet + netip.Prefix cidr + + []byte bytea + +Null Values + +pgtype can map NULLs in two ways. The first is types that can directly represent NULL such as Int4. They work in a +similar fashion to database/sql. The second is to use a pointer to a pointer. + + var foo pgtype.Text + var bar *string + err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar) + if err != nil { + return err + } + +JSON Support + +pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL types. + +Array Support + +ArrayCodec implements support for arrays. If pgtype supports type T then it can easily support []T by registering an +ArrayCodec for the appropriate PostgreSQL OID. In addition, Array[T] type can support multi-dimensional arrays. + +Composite Support + +CompositeCodec implements support for PostgreSQL composite types. Go structs can be scanned into if the public fields of +the struct are in the exact order and type of the PostgreSQL type or by implementing CompositeIndexScanner and +CompositeIndexGetter. + +Enum Support + +PostgreSQL enums can usually be treated as text. However, EnumCodec implements support for interning strings which can +reduce memory usage. + +Array, Composite, and Enum Type Registration + +Array, composite, and enum types can be easily registered from a pgx.Conn with the LoadType method. + +Extending Existing Type Support + +Generally, all Codecs will support interfaces that can be implemented to enable scanning and encoding. For example, +PointCodec can use any Go type that implements the PointScanner and PointValuer interfaces. So rather than use +pgtype.Point and application can directly use its own point type with pgtype as long as it implements those interfaces. + +See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. + +Sometimes pgx supports a PostgreSQL type such as numeric but the Go type is in an external package that does not have +pgx support such as github.com/shopspring/decimal. These types can be registered with pgtype with custom conversion +logic. See https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid for a example +integrations. + +Entirely New Type Support + +If the PostgreSQL type is not already supported then an OID / Codec mapping can be registered with Map.RegisterType. +There is no difference between a Codec defined and registered by the application and a Codec built in to pgtype. See any +of the Codecs in pgtype for Codec examples and for examples of type registration. + +Encoding Unknown Types + +pgtype works best when the OID of the PostgreSQL type is known. But in some cases such as using the simple protocol the +OID is unknown. In this case Map.RegisterDefaultPgType can be used to register an assumed OID for a particular Go type. + +Renamed Types + +If pgtype does not recognize a type and that type is a renamed simple type simple (e.g. type MyInt32 int32) pgtype acts +as if it is the underlying type. It currently cannot automatically detect the underlying type of renamed structs (eg.g. +type MyTime time.Time). + +Compatibility with database/sql + +pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer +interfaces. + +Child Records + +pgtype's support for arrays and composite records can be used to load records and their children in a single query. See +example_child_records_test.go for an example. + +Overview of Scanning Implementation + +The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID +from previous calls of Map.RegisterDefaultPgType. The Map will call the Codec's PlanScan method to get a plan for +scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types are +interfaces rather than explicit types. For example, PointCodec can use any Go type that implments the PointScanner and +PointValuer interfaces. + +If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again. +For example, Int8Codec does not support scanning into a renamed type (e.g. type myInt64 int64). But Map will detect that +myInt64 is a renamed type and create a plan that converts the value to the underlying int64 type and then passes that to +the Codec (see TryFindUnderlyingTypeScanPlan). + +These plan wrappers are contained in Map.TryWrapScanPlanFuncs. By default these contain shared logic to handle renamed +types, pointers to pointers, slices, composite types, etc. Additional plan wrappers can be added to seamlessly integrate +types that do not support pgx directly. For example, the before mentioned +https://github.com/jackc/pgx-shopspring-decimal package detects decimal.Decimal values, wraps them in something +implementing NumericScanner and passes that to the Codec. + +Map.Scan and Map.Encode are convenience methods that wrap Map.PlanScan and Map.PlanEncode. Determining how to scan or +encode a particular type may be a time consuming operation. Hence the planning and execution steps of a conversion are +internally separated. + +Reducing Compiled Binary Size + +pgx.QueryExecModeExec and pgx.QueryExecModeSimpleProtocol require the default PostgreSQL type to be registered for each +Go type used as a query parameter. By default pgx does this for all supported types and their array variants. If an +application does not use those query execution modes or manually registers the default PostgreSQL type for the types it +uses as query parameters it can use the build tag nopgxregisterdefaulttypes. This omits the default type registration +and reduces the compiled binary size by ~2MB. +*/ +package pgtype diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go new file mode 100644 index 00000000..5e787c1e --- /dev/null +++ b/pgtype/enum_codec.go @@ -0,0 +1,109 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +// EnumCodec is a codec that caches the strings it decodes. If the same string is read multiple times only one copy is +// allocated. These strings are only garbage collected when the EnumCodec is garbage collected. EnumCodec can be used +// for any text type not only enums, but it should only be used when there are a small number of possible values. +type EnumCodec struct { + membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating +} + +func (EnumCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (EnumCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (EnumCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch value.(type) { + case string: + return encodePlanTextCodecString{} + case []byte: + return encodePlanTextCodecByteSlice{} + case TextValuer: + return encodePlanTextCodecTextValuer{} + } + } + + return nil +} + +func (c *EnumCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch target.(type) { + case *string: + return &scanPlanTextAnyToEnumString{codec: c} + case *[]byte: + return scanPlanAnyToNewByteSlice{} + case TextScanner: + return &scanPlanTextAnyToEnumTextScanner{codec: c} + } + } + + return nil +} + +func (c *EnumCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) +} + +func (c *EnumCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + return c.lookupAndCacheString(src), nil +} + +// lookupAndCacheString looks for src in the members map. If it is not found it is added to the map. +func (c *EnumCodec) lookupAndCacheString(src []byte) string { + if c.membersMap == nil { + c.membersMap = make(map[string]string) + } + + if s, found := c.membersMap[string(src)]; found { + return s + } + + s := string(src) + c.membersMap[s] = s + return s +} + +type scanPlanTextAnyToEnumString struct { + codec *EnumCodec +} + +func (plan *scanPlanTextAnyToEnumString) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p := (dst).(*string) + *p = plan.codec.lookupAndCacheString(src) + + return nil +} + +type scanPlanTextAnyToEnumTextScanner struct { + codec *EnumCodec +} + +func (plan *scanPlanTextAnyToEnumTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + return scanner.ScanText(Text{String: plan.codec.lookupAndCacheString(src), Valid: true}) +} diff --git a/pgtype/enum_codec_test.go b/pgtype/enum_codec_test.go new file mode 100644 index 00000000..d064d49c --- /dev/null +++ b/pgtype/enum_codec_test.go @@ -0,0 +1,69 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" +) + +func TestEnumCodec(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + + _, err := conn.Exec(ctx, `drop type if exists enum_test; + +create type enum_test as enum ('foo', 'bar', 'baz');`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type enum_test") + + dt, err := conn.LoadType(ctx, "enum_test") + require.NoError(t, err) + + conn.TypeMap().RegisterType(dt) + + var s string + err = conn.QueryRow(ctx, `select 'foo'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "foo", s) + + err = conn.QueryRow(ctx, `select $1::enum_test`, "bar").Scan(&s) + require.NoError(t, err) + require.Equal(t, "bar", s) + + err = conn.QueryRow(ctx, `select 'foo'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "foo", s) + + err = conn.QueryRow(ctx, `select $1::enum_test`, "bar").Scan(&s) + require.NoError(t, err) + require.Equal(t, "bar", s) + + err = conn.QueryRow(ctx, `select 'baz'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "baz", s) + }) +} + +func TestEnumCodecValues(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + + _, err := conn.Exec(ctx, `drop type if exists enum_test; + +create type enum_test as enum ('foo', 'bar', 'baz');`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type enum_test") + + dt, err := conn.LoadType(ctx, "enum_test") + require.NoError(t, err) + + conn.TypeMap().RegisterType(dt) + + rows, err := conn.Query(ctx, `select 'foo'::enum_test`) + require.NoError(t, err) + require.True(t, rows.Next()) + values, err := rows.Values() + require.NoError(t, err) + require.Equal(t, values, []any{"foo"}) + }) +} diff --git a/pgtype/example_child_records_test.go b/pgtype/example_child_records_test.go new file mode 100644 index 00000000..9a4218ba --- /dev/null +++ b/pgtype/example_child_records_test.go @@ -0,0 +1,103 @@ +package pgtype_test + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/jackc/pgx/v5" +) + +type Player struct { + Name string + Position string +} + +type Team struct { + Name string + Players []Player +} + +// This example uses a single query to return parent and child records. +func Example_childRecords() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. + fmt.Println(`Alpha + Adam: wing + Bill: halfback + Charlie: fullback +Beta + Don: halfback + Edgar: halfback + Frank: fullback`) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table teams ( + name text primary key +); + +create temporary table players ( + name text primary key, + team_name text, + position text +); + +insert into teams (name) values + ('Alpha'), + ('Beta'); + +insert into players (name, team_name, position) values + ('Adam', 'Alpha', 'wing'), + ('Bill', 'Alpha', 'halfback'), + ('Charlie', 'Alpha', 'fullback'), + ('Don', 'Beta', 'halfback'), + ('Edgar', 'Beta', 'halfback'), + ('Frank', 'Beta', 'fullback') +`) + if err != nil { + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + rows, _ := conn.Query(ctx, ` +select t.name, + (select array_agg(row(p.name, position) order by p.name) from players p where p.team_name = t.name) +from teams t +order by t.name +`) + teams, err := pgx.CollectRows(rows, pgx.RowToStructByPos[Team]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, team := range teams { + fmt.Println(team.Name) + for _, player := range team.Players { + fmt.Printf(" %s: %s\n", player.Name, player.Position) + } + } + + // Output: + // Alpha + // Adam: wing + // Bill: halfback + // Charlie: fullback + // Beta + // Don: halfback + // Edgar: halfback + // Frank: fullback +} diff --git a/pgtype/example_custom_type_test.go b/pgtype/example_custom_type_test.go new file mode 100644 index 00000000..ceb9a0aa --- /dev/null +++ b/pgtype/example_custom_type_test.go @@ -0,0 +1,75 @@ +package pgtype_test + +import ( + "context" + "fmt" + "os" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +// Point represents a point that may be null. +type Point struct { + X, Y float32 // Coordinates of point + Valid bool +} + +func (p *Point) ScanPoint(v pgtype.Point) error { + *p = Point{ + X: float32(v.P.X), + Y: float32(v.P.Y), + Valid: v.Valid, + } + return nil +} + +func (p Point) PointValue() (pgtype.Point, error) { + return pgtype.Point{ + P: pgtype.Vec2{X: float64(p.X), Y: float64(p.Y)}, + Valid: true, + }, nil +} + +func (src *Point) String() string { + if !src.Valid { + return "null point" + } + + return fmt.Sprintf("%.1f, %.1f", src.X, src.Y) +} + +func Example_customType() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + defer conn.Close(context.Background()) + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB which doesn't support the point type. Since an example can't be + // skipped fake success instead. + fmt.Println("null point") + fmt.Println("1.5, 2.5") + return + } + + p := &Point{} + err = conn.QueryRow(context.Background(), "select null::point").Scan(p) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(p) + + err = conn.QueryRow(context.Background(), "select point(1.5,2.5)").Scan(p) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(p) + // Output: + // null point + // 1.5, 2.5 +} diff --git a/example_json_test.go b/pgtype/example_json_test.go similarity index 88% rename from example_json_test.go rename to pgtype/example_json_test.go index 33bd7519..98fb675a 100644 --- a/example_json_test.go +++ b/pgtype/example_json_test.go @@ -1,14 +1,14 @@ -package pgx_test +package pgtype_test import ( "context" "fmt" "os" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" ) -func Example_JSON() { +func Example_json() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { fmt.Printf("Unable to establish connection: %v", err) diff --git a/pgtype/float4.go b/pgtype/float4.go new file mode 100644 index 00000000..2540f9e5 --- /dev/null +++ b/pgtype/float4.go @@ -0,0 +1,295 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Float4 struct { + Float32 float32 + Valid bool +} + +// ScanFloat64 implements the Float64Scanner interface. +func (f *Float4) ScanFloat64(n Float8) error { + *f = Float4{Float32: float32(n.Float64), Valid: n.Valid} + return nil +} + +func (f Float4) Float64Value() (Float8, error) { + return Float8{Float64: float64(f.Float32), Valid: f.Valid}, nil +} + +func (f *Float4) ScanInt64(n Int8) error { + *f = Float4{Float32: float32(n.Int64), Valid: n.Valid} + return nil +} + +func (f Float4) Int64Value() (Int8, error) { + return Int8{Int64: int64(f.Float32), Valid: f.Valid}, nil +} + +// Scan implements the database/sql Scanner interface. +func (f *Float4) Scan(src any) error { + if src == nil { + *f = Float4{} + return nil + } + + switch src := src.(type) { + case float64: + *f = Float4{Float32: float32(src), Valid: true} + return nil + case string: + n, err := strconv.ParseFloat(string(src), 32) + if err != nil { + return err + } + *f = Float4{Float32: float32(n), Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (f Float4) Value() (driver.Value, error) { + if !f.Valid { + return nil, nil + } + return float64(f.Float32), nil +} + +type Float4Codec struct{} + +func (Float4Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Float4Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Float4Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case float32: + return encodePlanFloat4CodecBinaryFloat32{} + case Float64Valuer: + return encodePlanFloat4CodecBinaryFloat64Valuer{} + case Int64Valuer: + return encodePlanFloat4CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case float32: + return encodePlanTextFloat32{} + case Float64Valuer: + return encodePlanTextFloat64Valuer{} + case Int64Valuer: + return encodePlanTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanFloat4CodecBinaryFloat32 struct{} + +func (encodePlanFloat4CodecBinaryFloat32) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(float32) + return pgio.AppendUint32(buf, math.Float32bits(n)), nil +} + +type encodePlanTextFloat32 struct{} + +func (encodePlanTextFloat32) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(float32) + return append(buf, strconv.FormatFloat(float64(n), 'f', -1, 32)...), nil +} + +type encodePlanFloat4CodecBinaryFloat64Valuer struct{} + +func (encodePlanFloat4CodecBinaryFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return pgio.AppendUint32(buf, math.Float32bits(float32(n.Float64))), nil +} + +type encodePlanFloat4CodecBinaryInt64Valuer struct{} + +func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + f := float32(n.Int64) + return pgio.AppendUint32(buf, math.Float32bits(f)), nil +} + +func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *float32: + return scanPlanBinaryFloat4ToFloat32{} + case Float64Scanner: + return scanPlanBinaryFloat4ToFloat64Scanner{} + case Int64Scanner: + return scanPlanBinaryFloat4ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryFloat4ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *float32: + return scanPlanTextAnyToFloat32{} + case Float64Scanner: + return scanPlanTextAnyToFloat64Scanner{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +type scanPlanBinaryFloat4ToFloat32 struct{} + +func (scanPlanBinaryFloat4ToFloat32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + n := int32(binary.BigEndian.Uint32(src)) + f := (dst).(*float32) + *f = math.Float32frombits(uint32(n)) + + return nil +} + +type scanPlanBinaryFloat4ToFloat64Scanner struct{} + +func (scanPlanBinaryFloat4ToFloat64Scanner) Scan(src []byte, dst any) error { + s := (dst).(Float64Scanner) + + if src == nil { + return s.ScanFloat64(Float8{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + n := int32(binary.BigEndian.Uint32(src)) + return s.ScanFloat64(Float8{Float64: float64(math.Float32frombits(uint32(n))), Valid: true}) +} + +type scanPlanBinaryFloat4ToInt64Scanner struct{} + +func (scanPlanBinaryFloat4ToInt64Scanner) Scan(src []byte, dst any) error { + s := (dst).(Int64Scanner) + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + ui32 := int32(binary.BigEndian.Uint32(src)) + f32 := math.Float32frombits(uint32(ui32)) + i64 := int64(f32) + if f32 != float32(i64) { + return fmt.Errorf("cannot losslessly convert %v to int64", f32) + } + + return s.ScanInt64(Int8{Int64: i64, Valid: true}) +} + +type scanPlanBinaryFloat4ToTextScanner struct{} + +func (scanPlanBinaryFloat4ToTextScanner) Scan(src []byte, dst any) error { + s := (dst).(TextScanner) + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + ui32 := int32(binary.BigEndian.Uint32(src)) + f32 := math.Float32frombits(uint32(ui32)) + + return s.ScanText(Text{String: strconv.FormatFloat(float64(f32), 'f', -1, 32), Valid: true}) +} + +type scanPlanTextAnyToFloat32 struct{} + +func (scanPlanTextAnyToFloat32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + n, err := strconv.ParseFloat(string(src), 32) + if err != nil { + return err + } + + f := (dst).(*float32) + *f = float32(n) + + return nil +} + +func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n float64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n float32 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go new file mode 100644 index 00000000..f155ed97 --- /dev/null +++ b/pgtype/float4_test.go @@ -0,0 +1,23 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestFloat4Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float4", []pgxtest.ValueRoundTripTest{ + {pgtype.Float4{Float32: -1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: -1, Valid: true})}, + {pgtype.Float4{Float32: 0, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: 0, Valid: true})}, + {pgtype.Float4{Float32: 1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: 1, Valid: true})}, + {float32(0.00001), new(float32), isExpectedEq(float32(0.00001))}, + {float32(9999.99), new(float32), isExpectedEq(float32(9999.99))}, + {pgtype.Float4{}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{})}, + {int64(1), new(int64), isExpectedEq(int64(1))}, + {"1.23", new(string), isExpectedEq("1.23")}, + {nil, new(*float32), isExpectedEq((*float32)(nil))}, + }) +} diff --git a/pgtype/float8.go b/pgtype/float8.go new file mode 100644 index 00000000..6af27d6f --- /dev/null +++ b/pgtype/float8.go @@ -0,0 +1,341 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Float64Scanner interface { + ScanFloat64(Float8) error +} + +type Float64Valuer interface { + Float64Value() (Float8, error) +} + +type Float8 struct { + Float64 float64 + Valid bool +} + +// ScanFloat64 implements the Float64Scanner interface. +func (f *Float8) ScanFloat64(n Float8) error { + *f = n + return nil +} + +func (f Float8) Float64Value() (Float8, error) { + return f, nil +} + +func (f *Float8) ScanInt64(n Int8) error { + *f = Float8{Float64: float64(n.Int64), Valid: n.Valid} + return nil +} + +func (f Float8) Int64Value() (Int8, error) { + return Int8{Int64: int64(f.Float64), Valid: f.Valid}, nil +} + +// Scan implements the database/sql Scanner interface. +func (f *Float8) Scan(src any) error { + if src == nil { + *f = Float8{} + return nil + } + + switch src := src.(type) { + case float64: + *f = Float8{Float64: src, Valid: true} + return nil + case string: + n, err := strconv.ParseFloat(string(src), 64) + if err != nil { + return err + } + *f = Float8{Float64: n, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (f Float8) Value() (driver.Value, error) { + if !f.Valid { + return nil, nil + } + return f.Float64, nil +} + +type Float8Codec struct{} + +func (Float8Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Float8Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Float8Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case float64: + return encodePlanFloat8CodecBinaryFloat64{} + case Float64Valuer: + return encodePlanFloat8CodecBinaryFloat64Valuer{} + case Int64Valuer: + return encodePlanFloat8CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case float64: + return encodePlanTextFloat64{} + case Float64Valuer: + return encodePlanTextFloat64Valuer{} + case Int64Valuer: + return encodePlanTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanFloat8CodecBinaryFloat64 struct{} + +func (encodePlanFloat8CodecBinaryFloat64) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(float64) + return pgio.AppendUint64(buf, math.Float64bits(n)), nil +} + +type encodePlanTextFloat64 struct{} + +func (encodePlanTextFloat64) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(float64) + return append(buf, strconv.FormatFloat(n, 'f', -1, 64)...), nil +} + +type encodePlanFloat8CodecBinaryFloat64Valuer struct{} + +func (encodePlanFloat8CodecBinaryFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return pgio.AppendUint64(buf, math.Float64bits(n.Float64)), nil +} + +type encodePlanTextFloat64Valuer struct{} + +func (encodePlanTextFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return append(buf, strconv.FormatFloat(n.Float64, 'f', -1, 64)...), nil +} + +type encodePlanFloat8CodecBinaryInt64Valuer struct{} + +func (encodePlanFloat8CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + f := float64(n.Int64) + return pgio.AppendUint64(buf, math.Float64bits(f)), nil +} + +type encodePlanTextInt64Valuer struct{} + +func (encodePlanTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil +} + +func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *float64: + return scanPlanBinaryFloat8ToFloat64{} + case Float64Scanner: + return scanPlanBinaryFloat8ToFloat64Scanner{} + case Int64Scanner: + return scanPlanBinaryFloat8ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryFloat8ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *float64: + return scanPlanTextAnyToFloat64{} + case Float64Scanner: + return scanPlanTextAnyToFloat64Scanner{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +type scanPlanBinaryFloat8ToFloat64 struct{} + +func (scanPlanBinaryFloat8ToFloat64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for float8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + f := (dst).(*float64) + *f = math.Float64frombits(uint64(n)) + + return nil +} + +type scanPlanBinaryFloat8ToFloat64Scanner struct{} + +func (scanPlanBinaryFloat8ToFloat64Scanner) Scan(src []byte, dst any) error { + s := (dst).(Float64Scanner) + + if src == nil { + return s.ScanFloat64(Float8{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for float8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + return s.ScanFloat64(Float8{Float64: math.Float64frombits(uint64(n)), Valid: true}) +} + +type scanPlanBinaryFloat8ToInt64Scanner struct{} + +func (scanPlanBinaryFloat8ToInt64Scanner) Scan(src []byte, dst any) error { + s := (dst).(Int64Scanner) + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for float8: %v", len(src)) + } + + ui64 := int64(binary.BigEndian.Uint64(src)) + f64 := math.Float64frombits(uint64(ui64)) + i64 := int64(f64) + if f64 != float64(i64) { + return fmt.Errorf("cannot losslessly convert %v to int64", f64) + } + + return s.ScanInt64(Int8{Int64: i64, Valid: true}) +} + +type scanPlanBinaryFloat8ToTextScanner struct{} + +func (scanPlanBinaryFloat8ToTextScanner) Scan(src []byte, dst any) error { + s := (dst).(TextScanner) + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for float8: %v", len(src)) + } + + ui64 := int64(binary.BigEndian.Uint64(src)) + f64 := math.Float64frombits(uint64(ui64)) + + return s.ScanText(Text{String: strconv.FormatFloat(f64, 'f', -1, 64), Valid: true}) +} + +type scanPlanTextAnyToFloat64 struct{} + +func (scanPlanTextAnyToFloat64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + n, err := strconv.ParseFloat(string(src), 64) + if err != nil { + return err + } + + f := (dst).(*float64) + *f = n + + return nil +} + +type scanPlanTextAnyToFloat64Scanner struct{} + +func (scanPlanTextAnyToFloat64Scanner) Scan(src []byte, dst any) error { + s := (dst).(Float64Scanner) + + if src == nil { + return s.ScanFloat64(Float8{}) + } + + n, err := strconv.ParseFloat(string(src), 64) + if err != nil { + return err + } + + return s.ScanFloat64(Float8{Float64: n, Valid: true}) +} + +func (c Float8Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) +} + +func (c Float8Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n float64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go new file mode 100644 index 00000000..496b718b --- /dev/null +++ b/pgtype/float8_test.go @@ -0,0 +1,23 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestFloat8Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float8", []pgxtest.ValueRoundTripTest{ + {pgtype.Float8{Float64: -1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: -1, Valid: true})}, + {pgtype.Float8{Float64: 0, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: 0, Valid: true})}, + {pgtype.Float8{Float64: 1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: 1, Valid: true})}, + {float64(0.00001), new(float64), isExpectedEq(float64(0.00001))}, + {float64(9999.99), new(float64), isExpectedEq(float64(9999.99))}, + {pgtype.Float8{}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{})}, + {int64(1), new(int64), isExpectedEq(int64(1))}, + {"1.23", new(string), isExpectedEq("1.23")}, + {nil, new(*float64), isExpectedEq((*float64)(nil))}, + }) +} diff --git a/pgtype/hstore.go b/pgtype/hstore.go new file mode 100644 index 00000000..4743643e --- /dev/null +++ b/pgtype/hstore.go @@ -0,0 +1,461 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "strings" + "unicode" + "unicode/utf8" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type HstoreScanner interface { + ScanHstore(v Hstore) error +} + +type HstoreValuer interface { + HstoreValue() (Hstore, error) +} + +// Hstore represents an hstore column that can be null or have null values +// associated with its keys. +type Hstore map[string]*string + +func (h *Hstore) ScanHstore(v Hstore) error { + *h = v + return nil +} + +func (h Hstore) HstoreValue() (Hstore, error) { + return h, nil +} + +// Scan implements the database/sql Scanner interface. +func (h *Hstore) Scan(src any) error { + if src == nil { + *h = nil + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToHstoreScanner{}.Scan([]byte(src), h) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (h Hstore) Value() (driver.Value, error) { + if h == nil { + return nil, nil + } + + buf, err := HstoreCodec{}.PlanEncode(nil, 0, TextFormatCode, h).Encode(h, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type HstoreCodec struct{} + +func (HstoreCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (HstoreCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (HstoreCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(HstoreValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanHstoreCodecBinary{} + case TextFormatCode: + return encodePlanHstoreCodecText{} + } + + return nil +} + +type encodePlanHstoreCodecBinary struct{} + +func (encodePlanHstoreCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + hstore, err := value.(HstoreValuer).HstoreValue() + if err != nil { + return nil, err + } + + if hstore == nil { + return nil, nil + } + + buf = pgio.AppendInt32(buf, int32(len(hstore))) + + for k, v := range hstore { + buf = pgio.AppendInt32(buf, int32(len(k))) + buf = append(buf, k...) + + if v == nil { + buf = pgio.AppendInt32(buf, -1) + } else { + buf = pgio.AppendInt32(buf, int32(len(*v))) + buf = append(buf, (*v)...) + } + } + + return buf, nil +} + +type encodePlanHstoreCodecText struct{} + +func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + hstore, err := value.(HstoreValuer).HstoreValue() + if err != nil { + return nil, err + } + + if hstore == nil { + return nil, nil + } + + firstPair := true + + for k, v := range hstore { + if firstPair { + firstPair = false + } else { + buf = append(buf, ',') + } + + buf = append(buf, quoteHstoreElementIfNeeded(k)...) + buf = append(buf, "=>"...) + + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, quoteHstoreElementIfNeeded(*v)...) + } + } + + return buf, nil +} + +func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case HstoreScanner: + return scanPlanBinaryHstoreToHstoreScanner{} + } + case TextFormatCode: + switch target.(type) { + case HstoreScanner: + return scanPlanTextAnyToHstoreScanner{} + } + } + + return nil +} + +type scanPlanBinaryHstoreToHstoreScanner struct{} + +func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { + scanner := (dst).(HstoreScanner) + + if src == nil { + return scanner.ScanHstore(Hstore{}) + } + + rp := 0 + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + hstore := make(Hstore, pairCount) + + for i := 0; i < pairCount; i++ { + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + if len(src[rp:]) < keyLen { + return fmt.Errorf("hstore incomplete %v", src) + } + key := string(src[rp : rp+keyLen]) + rp += keyLen + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var valueBuf []byte + if valueLen >= 0 { + valueBuf = src[rp : rp+valueLen] + rp += valueLen + } + + var value Text + err := scanPlanTextAnyToTextScanner{}.Scan(valueBuf, &value) + if err != nil { + return err + } + + if value.Valid { + hstore[key] = &value.String + } else { + hstore[key] = nil + } + } + + return scanner.ScanHstore(hstore) +} + +type scanPlanTextAnyToHstoreScanner struct{} + +func (scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { + scanner := (dst).(HstoreScanner) + + if src == nil { + return scanner.ScanHstore(Hstore{}) + } + + keys, values, err := parseHstore(string(src)) + if err != nil { + return err + } + + m := make(Hstore, len(keys)) + for i := range keys { + if values[i].Valid { + m[keys[i]] = &values[i].String + } else { + m[keys[i]] = nil + } + } + + return scanner.ScanHstore(m) +} + +func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var hstore Hstore + err := codecScan(c, m, oid, format, src, &hstore) + if err != nil { + return nil, err + } + return hstore, nil +} + +var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteHstoreElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func quoteHstoreElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { + return quoteArrayElement(src) + } + return src +} + +const ( + hsPre = iota + hsKey + hsSep + hsVal + hsNul + hsNext +) + +type hstoreParser struct { + str string + pos int +} + +func newHSP(in string) *hstoreParser { + return &hstoreParser{ + pos: 0, + str: in, + } +} + +func (p *hstoreParser) Consume() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, w := utf8.DecodeRuneInString(p.str[p.pos:]) + p.pos += w + return +} + +func (p *hstoreParser) Peek() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) + return +} + +// parseHstore parses the string representation of an hstore column (the same +// you would get from an ordinary SELECT) into two slices of keys and values. it +// is used internally in the default parsing of hstores. +func parseHstore(s string) (k []string, v []Text, err error) { + if s == "" { + return + } + + buf := bytes.Buffer{} + keys := []string{} + values := []Text{} + p := newHSP(s) + + r, end := p.Consume() + state := hsPre + + for !end { + switch state { + case hsPre: + if r == '"' { + state = hsKey + } else { + err = errors.New("String does not begin with \"") + } + case hsKey: + switch r { + case '"': //End of the key + keys = append(keys, buf.String()) + buf = bytes.Buffer{} + state = hsSep + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsSep: + if r == '=' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=', expecting '>'") + case r == '>': + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") + case r == '"': + state = hsVal + case r == 'N': + state = hsNul + default: + err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) + } + default: + err = fmt.Errorf("Invalid character after '=', expecting '>'") + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) + } + case hsVal: + switch r { + case '"': //End of the value + values = append(values, Text{String: buf.String(), Valid: true}) + buf = bytes.Buffer{} + state = hsNext + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsNul: + nulBuf := make([]rune, 3) + nulBuf[0] = r + for i := 1; i < 3; i++ { + r, end = p.Consume() + if end { + err = errors.New("Found EOS in NULL value") + return + } + nulBuf[i] = r + } + if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { + values = append(values, Text{}) + state = hsNext + } else { + err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + } + case hsNext: + if r == ',' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after ',', expcting space") + case (unicode.IsSpace(r)): + r, end = p.Consume() + state = hsKey + default: + err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + } + } + + if err != nil { + return + } + r, end = p.Consume() + } + if state != hsNext { + err = errors.New("Improperly formatted hstore") + return + } + k = keys + v = values + return +} diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go new file mode 100644 index 00000000..f8684bf7 --- /dev/null +++ b/pgtype/hstore_test.go @@ -0,0 +1,182 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqMapStringString(a any) func(any) bool { + return func(v any) bool { + am := a.(map[string]string) + vm := v.(map[string]string) + + if len(am) != len(vm) { + return false + } + + for k, v := range am { + if vm[k] != v { + return false + } + } + + return true + } +} + +func isExpectedEqMapStringPointerString(a any) func(any) bool { + return func(v any) bool { + am := a.(map[string]*string) + vm := v.(map[string]*string) + + if len(am) != len(vm) { + return false + } + + for k, v := range am { + if (vm[k] == nil) != (v == nil) { + return false + } + + if v != nil && *vm[k] != *v { + return false + } + } + + return true + } +} + +func TestHstoreCodec(t *testing.T) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var hstoreOID uint32 + err := conn.QueryRow(context.Background(), `select oid from pg_type where typname = 'hstore'`).Scan(&hstoreOID) + if err != nil { + t.Skipf("Skipping: cannot find hstore OID") + } + + conn.TypeMap().RegisterType(&pgtype.Type{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) + } + + fs := func(s string) *string { + return &s + } + + tests := []pgxtest.ValueRoundTripTest{ + { + map[string]string{}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{}), + }, + { + map[string]string{"foo": "", "bar": "", "baz": "123"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "", "bar": "", "baz": "123"}), + }, + { + map[string]string{"NULL": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"NULL": "bar"}), + }, + { + map[string]string{"bar": "NULL"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"bar": "NULL"}), + }, + { + map[string]string{"": "foo"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"": "foo"}), + }, + { + map[string]*string{}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{}), + }, + { + map[string]*string{"foo": fs("bar"), "baq": fs("quz")}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{"foo": fs("bar"), "baq": fs("quz")}), + }, + { + map[string]*string{"foo": nil, "baq": fs("quz")}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{"foo": nil, "baq": fs("quz")}), + }, + {nil, new(*map[string]string), isExpectedEq((*map[string]string)(nil))}, + {nil, new(*map[string]*string), isExpectedEq((*map[string]*string)(nil))}, + {nil, new(*pgtype.Hstore), isExpectedEq((*pgtype.Hstore)(nil))}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + + // at beginning + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{s + "foo": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{s + "foo": "bar"}), + }) + // in middle + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo" + s + "bar": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo" + s + "bar": "bar"}), + }) + // at end + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo" + s: "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo" + s: "bar"}), + }) + // is key + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{s: "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{s: "bar"}), + }) + + // Special value values + + // at beginning + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo": s + "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": s + "bar"}), + }) + // in middle + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo": "foo" + s + "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "foo" + s + "bar"}), + }) + // at end + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo": "foo" + s}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "foo" + s}), + }) + // is key + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo": s}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": s}), + }) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, pgxtest.KnownOIDQueryExecModes, "hstore", tests) +} diff --git a/pgtype/inet.go b/pgtype/inet.go new file mode 100644 index 00000000..a85646d7 --- /dev/null +++ b/pgtype/inet.go @@ -0,0 +1,200 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "errors" + "fmt" + "net/netip" +) + +// Network address family is dependent on server socket.h value for AF_INET. +// In practice, all platforms appear to have the same value. See +// src/include/utils/inet.h for more information. +const ( + defaultAFInet = 2 + defaultAFInet6 = 3 +) + +type NetipPrefixScanner interface { + ScanNetipPrefix(v netip.Prefix) error +} + +type NetipPrefixValuer interface { + NetipPrefixValue() (netip.Prefix, error) +} + +// InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are netip.Prefix and netip.Addr. If +// IsValid() is false then they are treated as SQL NULL. +type InetCodec struct{} + +func (InetCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (InetCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (InetCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(NetipPrefixValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanInetCodecBinary{} + case TextFormatCode: + return encodePlanInetCodecText{} + } + + return nil +} + +type encodePlanInetCodecBinary struct{} + +func (encodePlanInetCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + prefix, err := value.(NetipPrefixValuer).NetipPrefixValue() + if err != nil { + return nil, err + } + + if !prefix.IsValid() { + return nil, nil + } + + var family byte + if prefix.Addr().Is4() { + family = defaultAFInet + } else { + family = defaultAFInet6 + } + + buf = append(buf, family) + + ones := prefix.Bits() + buf = append(buf, byte(ones)) + + // is_cidr is ignored on server + buf = append(buf, 0) + + if family == defaultAFInet { + buf = append(buf, byte(4)) + b := prefix.Addr().As4() + buf = append(buf, b[:]...) + } else { + buf = append(buf, byte(16)) + b := prefix.Addr().As16() + buf = append(buf, b[:]...) + } + + return buf, nil +} + +type encodePlanInetCodecText struct{} + +func (encodePlanInetCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + prefix, err := value.(NetipPrefixValuer).NetipPrefixValue() + if err != nil { + return nil, err + } + + if !prefix.IsValid() { + return nil, nil + } + + return append(buf, prefix.String()...), nil +} + +func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case NetipPrefixScanner: + return scanPlanBinaryInetToNetipPrefixScanner{} + } + case TextFormatCode: + switch target.(type) { + case NetipPrefixScanner: + return scanPlanTextAnyToNetipPrefixScanner{} + } + } + + return nil +} + +func (c InetCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c InetCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var prefix netip.Prefix + err := codecScan(c, m, oid, format, src, (*netipPrefixWrapper)(&prefix)) + if err != nil { + return nil, err + } + + if !prefix.IsValid() { + return nil, nil + } + + return prefix, nil +} + +type scanPlanBinaryInetToNetipPrefixScanner struct{} + +func (scanPlanBinaryInetToNetipPrefixScanner) Scan(src []byte, dst any) error { + scanner := (dst).(NetipPrefixScanner) + + if src == nil { + return scanner.ScanNetipPrefix(netip.Prefix{}) + } + + if len(src) != 8 && len(src) != 20 { + return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) + } + + // ignore family + bits := src[1] + // ignore is_cidr + // ignore addressLength - implicit in length of message + + addr, ok := netip.AddrFromSlice(src[4:]) + if !ok { + return errors.New("netip.AddrFromSlice failed") + } + + return scanner.ScanNetipPrefix(netip.PrefixFrom(addr, int(bits))) +} + +type scanPlanTextAnyToNetipPrefixScanner struct{} + +func (scanPlanTextAnyToNetipPrefixScanner) Scan(src []byte, dst any) error { + scanner := (dst).(NetipPrefixScanner) + + if src == nil { + return scanner.ScanNetipPrefix(netip.Prefix{}) + } + + var prefix netip.Prefix + if bytes.IndexByte(src, '/') == -1 { + addr, err := netip.ParseAddr(string(src)) + if err != nil { + return err + } + prefix = netip.PrefixFrom(addr, addr.BitLen()) + } else { + var err error + prefix, err = netip.ParsePrefix(string(src)) + if err != nil { + return err + } + } + + return scanner.ScanNetipPrefix(prefix) +} diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go new file mode 100644 index 00000000..f4b43daf --- /dev/null +++ b/pgtype/inet_test.go @@ -0,0 +1,99 @@ +package pgtype_test + +import ( + "context" + "net" + "net/netip" + "testing" + + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqIPNet(a any) func(any) bool { + return func(v any) bool { + ap := a.(*net.IPNet) + vp := v.(net.IPNet) + + return ap.IP.Equal(vp.IP) && ap.Mask.String() == vp.Mask.String() + } +} + +func TestInetTranscode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "inet", []pgxtest.ValueRoundTripTest{ + {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, + {mustParseInet(t, "127.0.0.1/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/8"))}, + {mustParseInet(t, "12.34.56.65/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.65/32"))}, + {mustParseInet(t, "192.168.1.16/24"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "192.168.1.16/24"))}, + {mustParseInet(t, "255.0.0.0/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.0.0.0/8"))}, + {mustParseInet(t, "255.255.255.255/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.255.255.255/32"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e"))}, + {mustParseInet(t, "::1/64"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/64"))}, + {mustParseInet(t, "::/0"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/0"))}, + {mustParseInet(t, "::1/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/128"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e/64"))}, + + {mustParseInet(t, "0.0.0.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("0.0.0.0/32"))}, + {mustParseInet(t, "127.0.0.1/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("127.0.0.1/8"))}, + {mustParseInet(t, "12.34.56.65/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("12.34.56.65/32"))}, + {mustParseInet(t, "192.168.1.16/24"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("192.168.1.16/24"))}, + {mustParseInet(t, "255.0.0.0/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.0.0.0/8"))}, + {mustParseInet(t, "255.255.255.255/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.255.255.255/32"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/128"))}, + {mustParseInet(t, "::1/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/64"))}, + {mustParseInet(t, "::/0"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/0"))}, + {mustParseInet(t, "::1/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/128"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/64"))}, + + {netip.MustParsePrefix("0.0.0.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("0.0.0.0/32"))}, + {netip.MustParsePrefix("127.0.0.1/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("127.0.0.1/8"))}, + {netip.MustParsePrefix("12.34.56.65/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("12.34.56.65/32"))}, + {netip.MustParsePrefix("192.168.1.16/24"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("192.168.1.16/24"))}, + {netip.MustParsePrefix("255.0.0.0/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.0.0.0/8"))}, + {netip.MustParsePrefix("255.255.255.255/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.255.255.255/32"))}, + {netip.MustParsePrefix("::1/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/64"))}, + {netip.MustParsePrefix("::/0"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/0"))}, + {netip.MustParsePrefix("::1/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/128"))}, + {netip.MustParsePrefix("2607:f8b0:4009:80b::200e/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/64"))}, + + {netip.MustParseAddr("0.0.0.0"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("0.0.0.0"))}, + {netip.MustParseAddr("127.0.0.1"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("127.0.0.1"))}, + {netip.MustParseAddr("12.34.56.65"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("12.34.56.65"))}, + {netip.MustParseAddr("192.168.1.16"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("192.168.1.16"))}, + {netip.MustParseAddr("255.0.0.0"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("255.0.0.0"))}, + {netip.MustParseAddr("255.255.255.255"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("255.255.255.255"))}, + {netip.MustParseAddr("2607:f8b0:4009:80b::200e"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("2607:f8b0:4009:80b::200e"))}, + {netip.MustParseAddr("::1"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("::1"))}, + {netip.MustParseAddr("::"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("::"))}, + {netip.MustParseAddr("2607:f8b0:4009:80b::200e"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("2607:f8b0:4009:80b::200e"))}, + + {nil, new(netip.Prefix), isExpectedEq(netip.Prefix{})}, + }) +} + +func TestCidrTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support cidr type (see https://github.com/cockroachdb/cockroach/issues/18846)") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "cidr", []pgxtest.ValueRoundTripTest{ + {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, + {mustParseInet(t, "127.0.0.1/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/32"))}, + {mustParseInet(t, "12.34.56.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.0/32"))}, + {mustParseInet(t, "192.168.1.0/24"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "192.168.1.0/24"))}, + {mustParseInet(t, "255.0.0.0/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.0.0.0/8"))}, + {mustParseInet(t, "::/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/128"))}, + {mustParseInet(t, "::/0"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/0"))}, + {mustParseInet(t, "::1/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/128"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e/128"))}, + + {netip.MustParsePrefix("0.0.0.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("0.0.0.0/32"))}, + {netip.MustParsePrefix("127.0.0.1/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("127.0.0.1/32"))}, + {netip.MustParsePrefix("12.34.56.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("12.34.56.0/32"))}, + {netip.MustParsePrefix("192.168.1.0/24"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("192.168.1.0/24"))}, + {netip.MustParsePrefix("255.0.0.0/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.0.0.0/8"))}, + {netip.MustParsePrefix("::/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/128"))}, + {netip.MustParsePrefix("::/0"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/0"))}, + {netip.MustParsePrefix("::1/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/128"))}, + {netip.MustParsePrefix("2607:f8b0:4009:80b::200e/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/128"))}, + + {nil, new(netip.Prefix), isExpectedEq(netip.Prefix{})}, + }) +} diff --git a/pgtype/int.go b/pgtype/int.go new file mode 100644 index 00000000..1cda0ba3 --- /dev/null +++ b/pgtype/int.go @@ -0,0 +1,1980 @@ +// Do not edit. Generated from pgtype/int.go.erb +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Int64Scanner interface { + ScanInt64(Int8) error +} + +type Int64Valuer interface { + Int64Value() (Int8, error) +} + +type Int2 struct { + Int16 int16 + Valid bool +} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int2) ScanInt64(n Int8) error { + if !n.Valid { + *dst = Int2{} + return nil + } + + if n.Int64 < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) + } + if n.Int64 > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) + } + *dst = Int2{Int16: int16(n.Int64), Valid: true} + + return nil +} + +func (n Int2) Int64Value() (Int8, error) { + return Int8{Int64: int64(n.Int16), Valid: n.Valid}, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src any) error { + if src == nil { + *dst = Int2{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, 16) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + *dst = Int2{Int16: int16(n), Valid: true} + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int16), nil +} + +func (src Int2) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int16), 10)), nil +} + +func (dst *Int2) UnmarshalJSON(b []byte) error { + var n *int16 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int2{} + } else { + *dst = Int2{Int16: *n, Valid: true} + } + + return nil +} + +type Int2Codec struct{} + +func (Int2Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int2Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int2Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case int16: + return encodePlanInt2CodecBinaryInt16{} + case Int64Valuer: + return encodePlanInt2CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case int16: + return encodePlanInt2CodecTextInt16{} + case Int64Valuer: + return encodePlanInt2CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanInt2CodecBinaryInt16 struct{} + +func (encodePlanInt2CodecBinaryInt16) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int16) + return pgio.AppendInt16(buf, int16(n)), nil +} + +type encodePlanInt2CodecTextInt16 struct{} + +func (encodePlanInt2CodecTextInt16) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int16) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt2CodecBinaryInt64Valuer struct{} + +func (encodePlanInt2CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int64) + } + if n.Int64 < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int64) + } + + return pgio.AppendInt16(buf, int16(n.Int64)), nil +} + +type encodePlanInt2CodecTextInt64Valuer struct{} + +func (encodePlanInt2CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int64) + } + if n.Int64 < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int64) + } + + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil +} + +func (Int2Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt2ToInt8{} + case *int16: + return scanPlanBinaryInt2ToInt16{} + case *int32: + return scanPlanBinaryInt2ToInt32{} + case *int64: + return scanPlanBinaryInt2ToInt64{} + case *int: + return scanPlanBinaryInt2ToInt{} + case *uint8: + return scanPlanBinaryInt2ToUint8{} + case *uint16: + return scanPlanBinaryInt2ToUint16{} + case *uint32: + return scanPlanBinaryInt2ToUint32{} + case *uint64: + return scanPlanBinaryInt2ToUint64{} + case *uint: + return scanPlanBinaryInt2ToUint{} + case Int64Scanner: + return scanPlanBinaryInt2ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt2ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +func (c Int2Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Int2Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n int16 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryInt2ToInt8 struct{} + +func (scanPlanBinaryInt2ToInt8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", n) + } else if n > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", n) + } + + *p = int8(n) + + return nil +} + +type scanPlanBinaryInt2ToUint8 struct{} + +func (scanPlanBinaryInt2ToUint8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", n) + } + + if n > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", n) + } + + *p = uint8(n) + + return nil +} + +type scanPlanBinaryInt2ToInt16 struct{} + +func (scanPlanBinaryInt2ToInt16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int16(binary.BigEndian.Uint16(src)) + + return nil +} + +type scanPlanBinaryInt2ToUint16 struct{} + +func (scanPlanBinaryInt2ToUint16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", n) + } + + *p = uint16(n) + + return nil +} + +type scanPlanBinaryInt2ToInt32 struct{} + +func (scanPlanBinaryInt2ToInt32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int32(int16(binary.BigEndian.Uint16(src))) + + return nil +} + +type scanPlanBinaryInt2ToUint32 struct{} + +func (scanPlanBinaryInt2ToUint32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", n) + } + + *p = uint32(n) + + return nil +} + +type scanPlanBinaryInt2ToInt64 struct{} + +func (scanPlanBinaryInt2ToInt64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int64(int16(binary.BigEndian.Uint16(src))) + + return nil +} + +type scanPlanBinaryInt2ToUint64 struct{} + +func (scanPlanBinaryInt2ToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", n) + } + + *p = uint64(n) + + return nil +} + +type scanPlanBinaryInt2ToInt struct{} + +func (scanPlanBinaryInt2ToInt) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int(int16(binary.BigEndian.Uint16(src))) + + return nil +} + +type scanPlanBinaryInt2ToUint struct{} + +func (scanPlanBinaryInt2ToUint) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(int16(binary.BigEndian.Uint16(src))) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + + *p = uint(n) + + return nil +} + +type scanPlanBinaryInt2ToInt64Scanner struct{} + +func (scanPlanBinaryInt2ToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + n := int64(int16(binary.BigEndian.Uint16(src))) + + return s.ScanInt64(Int8{Int64: n, Valid: true}) +} + +type scanPlanBinaryInt2ToTextScanner struct{} + +func (scanPlanBinaryInt2ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + n := int64(int16(binary.BigEndian.Uint16(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} + +type Int4 struct { + Int32 int32 + Valid bool +} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int4) ScanInt64(n Int8) error { + if !n.Valid { + *dst = Int4{} + return nil + } + + if n.Int64 < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) + } + if n.Int64 > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) + } + *dst = Int4{Int32: int32(n.Int64), Valid: true} + + return nil +} + +func (n Int4) Int64Value() (Int8, error) { + return Int8{Int64: int64(n.Int32), Valid: n.Valid}, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src any) error { + if src == nil { + *dst = Int4{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, 32) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + if n > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + *dst = Int4{Int32: int32(n), Valid: true} + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int32), nil +} + +func (src Int4) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int32), 10)), nil +} + +func (dst *Int4) UnmarshalJSON(b []byte) error { + var n *int32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int4{} + } else { + *dst = Int4{Int32: *n, Valid: true} + } + + return nil +} + +type Int4Codec struct{} + +func (Int4Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int4Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int4Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case int32: + return encodePlanInt4CodecBinaryInt32{} + case Int64Valuer: + return encodePlanInt4CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case int32: + return encodePlanInt4CodecTextInt32{} + case Int64Valuer: + return encodePlanInt4CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanInt4CodecBinaryInt32 struct{} + +func (encodePlanInt4CodecBinaryInt32) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int32) + return pgio.AppendInt32(buf, int32(n)), nil +} + +type encodePlanInt4CodecTextInt32 struct{} + +func (encodePlanInt4CodecTextInt32) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int32) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt4CodecBinaryInt64Valuer struct{} + +func (encodePlanInt4CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int64) + } + if n.Int64 < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int64) + } + + return pgio.AppendInt32(buf, int32(n.Int64)), nil +} + +type encodePlanInt4CodecTextInt64Valuer struct{} + +func (encodePlanInt4CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int64) + } + if n.Int64 < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int64) + } + + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil +} + +func (Int4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt4ToInt8{} + case *int16: + return scanPlanBinaryInt4ToInt16{} + case *int32: + return scanPlanBinaryInt4ToInt32{} + case *int64: + return scanPlanBinaryInt4ToInt64{} + case *int: + return scanPlanBinaryInt4ToInt{} + case *uint8: + return scanPlanBinaryInt4ToUint8{} + case *uint16: + return scanPlanBinaryInt4ToUint16{} + case *uint32: + return scanPlanBinaryInt4ToUint32{} + case *uint64: + return scanPlanBinaryInt4ToUint64{} + case *uint: + return scanPlanBinaryInt4ToUint{} + case Int64Scanner: + return scanPlanBinaryInt4ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt4ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +func (c Int4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Int4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n int32 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryInt4ToInt8 struct{} + +func (scanPlanBinaryInt4ToInt8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", n) + } else if n > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", n) + } + + *p = int8(n) + + return nil +} + +type scanPlanBinaryInt4ToUint8 struct{} + +func (scanPlanBinaryInt4ToUint8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", n) + } + + if n > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", n) + } + + *p = uint8(n) + + return nil +} + +type scanPlanBinaryInt4ToInt16 struct{} + +func (scanPlanBinaryInt4ToInt16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", n) + } else if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", n) + } + + *p = int16(n) + + return nil +} + +type scanPlanBinaryInt4ToUint16 struct{} + +func (scanPlanBinaryInt4ToUint16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", n) + } + + if n > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", n) + } + + *p = uint16(n) + + return nil +} + +type scanPlanBinaryInt4ToInt32 struct{} + +func (scanPlanBinaryInt4ToInt32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int32(binary.BigEndian.Uint32(src)) + + return nil +} + +type scanPlanBinaryInt4ToUint32 struct{} + +func (scanPlanBinaryInt4ToUint32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", n) + } + + *p = uint32(n) + + return nil +} + +type scanPlanBinaryInt4ToInt64 struct{} + +func (scanPlanBinaryInt4ToInt64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int64(int32(binary.BigEndian.Uint32(src))) + + return nil +} + +type scanPlanBinaryInt4ToUint64 struct{} + +func (scanPlanBinaryInt4ToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", n) + } + + *p = uint64(n) + + return nil +} + +type scanPlanBinaryInt4ToInt struct{} + +func (scanPlanBinaryInt4ToInt) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int(int32(binary.BigEndian.Uint32(src))) + + return nil +} + +type scanPlanBinaryInt4ToUint struct{} + +func (scanPlanBinaryInt4ToUint) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(int32(binary.BigEndian.Uint32(src))) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + + *p = uint(n) + + return nil +} + +type scanPlanBinaryInt4ToInt64Scanner struct{} + +func (scanPlanBinaryInt4ToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + n := int64(int32(binary.BigEndian.Uint32(src))) + + return s.ScanInt64(Int8{Int64: n, Valid: true}) +} + +type scanPlanBinaryInt4ToTextScanner struct{} + +func (scanPlanBinaryInt4ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + n := int64(int32(binary.BigEndian.Uint32(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} + +type Int8 struct { + Int64 int64 + Valid bool +} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int8) ScanInt64(n Int8) error { + if !n.Valid { + *dst = Int8{} + return nil + } + + if n.Int64 < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) + } + if n.Int64 > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) + } + *dst = Int8{Int64: int64(n.Int64), Valid: true} + + return nil +} + +func (n Int8) Int64Value() (Int8, error) { + return Int8{Int64: int64(n.Int64), Valid: n.Valid}, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src any) error { + if src == nil { + *dst = Int8{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, 64) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + if n > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + *dst = Int8{Int64: int64(n), Valid: true} + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int64), nil +} + +func (src Int8) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int64), 10)), nil +} + +func (dst *Int8) UnmarshalJSON(b []byte) error { + var n *int64 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int8{} + } else { + *dst = Int8{Int64: *n, Valid: true} + } + + return nil +} + +type Int8Codec struct{} + +func (Int8Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int8Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int8Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case int64: + return encodePlanInt8CodecBinaryInt64{} + case Int64Valuer: + return encodePlanInt8CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case int64: + return encodePlanInt8CodecTextInt64{} + case Int64Valuer: + return encodePlanInt8CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanInt8CodecBinaryInt64 struct{} + +func (encodePlanInt8CodecBinaryInt64) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int64) + return pgio.AppendInt64(buf, int64(n)), nil +} + +type encodePlanInt8CodecTextInt64 struct{} + +func (encodePlanInt8CodecTextInt64) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int64) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt8CodecBinaryInt64Valuer struct{} + +func (encodePlanInt8CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int64) + } + if n.Int64 < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int64) + } + + return pgio.AppendInt64(buf, int64(n.Int64)), nil +} + +type encodePlanInt8CodecTextInt64Valuer struct{} + +func (encodePlanInt8CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int64) + } + if n.Int64 < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int64) + } + + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil +} + +func (Int8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt8ToInt8{} + case *int16: + return scanPlanBinaryInt8ToInt16{} + case *int32: + return scanPlanBinaryInt8ToInt32{} + case *int64: + return scanPlanBinaryInt8ToInt64{} + case *int: + return scanPlanBinaryInt8ToInt{} + case *uint8: + return scanPlanBinaryInt8ToUint8{} + case *uint16: + return scanPlanBinaryInt8ToUint16{} + case *uint32: + return scanPlanBinaryInt8ToUint32{} + case *uint64: + return scanPlanBinaryInt8ToUint64{} + case *uint: + return scanPlanBinaryInt8ToUint{} + case Int64Scanner: + return scanPlanBinaryInt8ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt8ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +func (c Int8Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Int8Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryInt8ToInt8 struct{} + +func (scanPlanBinaryInt8ToInt8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", n) + } else if n > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", n) + } + + *p = int8(n) + + return nil +} + +type scanPlanBinaryInt8ToUint8 struct{} + +func (scanPlanBinaryInt8ToUint8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", n) + } + + if n > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", n) + } + + *p = uint8(n) + + return nil +} + +type scanPlanBinaryInt8ToInt16 struct{} + +func (scanPlanBinaryInt8ToInt16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", n) + } else if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", n) + } + + *p = int16(n) + + return nil +} + +type scanPlanBinaryInt8ToUint16 struct{} + +func (scanPlanBinaryInt8ToUint16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", n) + } + + if n > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", n) + } + + *p = uint16(n) + + return nil +} + +type scanPlanBinaryInt8ToInt32 struct{} + +func (scanPlanBinaryInt8ToInt32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", n) + } else if n > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", n) + } + + *p = int32(n) + + return nil +} + +type scanPlanBinaryInt8ToUint32 struct{} + +func (scanPlanBinaryInt8ToUint32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", n) + } + + if n > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", n) + } + + *p = uint32(n) + + return nil +} + +type scanPlanBinaryInt8ToInt64 struct{} + +func (scanPlanBinaryInt8ToInt64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int64(binary.BigEndian.Uint64(src)) + + return nil +} + +type scanPlanBinaryInt8ToUint64 struct{} + +func (scanPlanBinaryInt8ToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", n) + } + + *p = uint64(n) + + return nil +} + +type scanPlanBinaryInt8ToInt struct{} + +func (scanPlanBinaryInt8ToInt) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", n) + } else if n > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", n) + } + + *p = int(n) + + return nil +} + +type scanPlanBinaryInt8ToUint struct{} + +func (scanPlanBinaryInt8ToUint) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(int64(binary.BigEndian.Uint64(src))) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + + if uint64(n) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", n) + } + + *p = uint(n) + + return nil +} + +type scanPlanBinaryInt8ToInt64Scanner struct{} + +func (scanPlanBinaryInt8ToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + n := int64(int64(binary.BigEndian.Uint64(src))) + + return s.ScanInt64(Int8{Int64: n, Valid: true}) +} + +type scanPlanBinaryInt8ToTextScanner struct{} + +func (scanPlanBinaryInt8ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + n := int64(int64(binary.BigEndian.Uint64(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} + +type scanPlanTextAnyToInt8 struct{} + +func (scanPlanTextAnyToInt8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 8) + if err != nil { + return err + } + + *p = int8(n) + return nil +} + +type scanPlanTextAnyToUint8 struct{} + +func (scanPlanTextAnyToUint8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 8) + if err != nil { + return err + } + + *p = uint8(n) + return nil +} + +type scanPlanTextAnyToInt16 struct{} + +func (scanPlanTextAnyToInt16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + + *p = int16(n) + return nil +} + +type scanPlanTextAnyToUint16 struct{} + +func (scanPlanTextAnyToUint16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 16) + if err != nil { + return err + } + + *p = uint16(n) + return nil +} + +type scanPlanTextAnyToInt32 struct{} + +func (scanPlanTextAnyToInt32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + + *p = int32(n) + return nil +} + +type scanPlanTextAnyToUint32 struct{} + +func (scanPlanTextAnyToUint32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *p = uint32(n) + return nil +} + +type scanPlanTextAnyToInt64 struct{} + +func (scanPlanTextAnyToInt64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + *p = int64(n) + return nil +} + +type scanPlanTextAnyToUint64 struct{} + +func (scanPlanTextAnyToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 64) + if err != nil { + return err + } + + *p = uint64(n) + return nil +} + +type scanPlanTextAnyToInt struct{} + +func (scanPlanTextAnyToInt) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 0) + if err != nil { + return err + } + + *p = int(n) + return nil +} + +type scanPlanTextAnyToUint struct{} + +func (scanPlanTextAnyToUint) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 0) + if err != nil { + return err + } + + *p = uint(n) + return nil +} + +type scanPlanTextAnyToInt64Scanner struct{} + +func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + err = s.ScanInt64(Int8{Int64: n, Valid: true}) + if err != nil { + return err + } + + return nil +} diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb new file mode 100644 index 00000000..572408e1 --- /dev/null +++ b/pgtype/int.go.erb @@ -0,0 +1,547 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Int64Scanner interface { + ScanInt64(Int8) error +} + +type Int64Valuer interface { + Int64Value() (Int8, error) +} + + +<% [2, 4, 8].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +type Int<%= pg_byte_size %> struct { + Int<%= pg_bit_size %> int<%= pg_bit_size %> + Valid bool +} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { + if !n.Valid { + *dst = Int<%= pg_byte_size %>{} + return nil + } + + if n.Int64 < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) + } + if n.Int64 > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) + } + *dst = Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: int<%= pg_bit_size %>(n.Int64), Valid: true} + + return nil +} + +func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) { + return Int8{Int64: int64(n.Int<%= pg_bit_size %>), Valid: n.Valid}, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int<%= pg_byte_size %>) Scan(src any) error { + if src == nil { + *dst = Int<%= pg_byte_size %>{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, <%= pg_bit_size %>) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, <%= pg_bit_size %>) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + if n > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + *dst = Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: int<%= pg_bit_size %>(n), Valid: true} + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int<%= pg_bit_size %>), nil +} + +func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int<%= pg_bit_size %>), 10)), nil +} + +func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error { + var n *int<%= pg_bit_size %> + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int<%= pg_byte_size %>{} + } else { + *dst = Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: *n, Valid: true} + } + + return nil +} + +type Int<%= pg_byte_size %>Codec struct{} + +func (Int<%= pg_byte_size %>Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int<%= pg_byte_size %>Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int<%= pg_byte_size %>Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case int<%= pg_bit_size %>: + return encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %>{} + case Int64Valuer: + return encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case int<%= pg_bit_size %>: + return encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %>{} + case Int64Valuer: + return encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %> struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %>) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int<%= pg_bit_size %>) + return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n)), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %> struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %>) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int<%= pg_bit_size %>) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int64) + } + if n.Int64 < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int64) + } + + return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n.Int64)), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int64) + } + if n.Int64 < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int64) + } + + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil +} + +func (Int<%= pg_byte_size %>Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt8{} + case *int16: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt16{} + case *int32: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt32{} + case *int64: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt64{} + case *int: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt{} + case *uint8: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint8{} + case *uint16: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint16{} + case *uint32: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint32{} + case *uint64: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint64{} + case *uint: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint{} + case Int64Scanner: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +func (c Int<%= pg_byte_size %>Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Int<%= pg_byte_size %>Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n int<%= pg_bit_size %> + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +<%# PostgreSQL binary format integer to fixed size Go integers %> +<% [8, 16, 32, 64].each do |dst_bit_size| %> +type scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %> struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %>) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*int<%= dst_bit_size %>) + if !ok { + return ErrScanTargetTypeChanged + } + + <% if dst_bit_size < pg_bit_size %> + n := int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + if n < math.MinInt<%= dst_bit_size %> { + return fmt.Errorf("%d is less than minimum value for int<%= dst_bit_size %>", n) + } else if n > math.MaxInt<%= dst_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for int<%= dst_bit_size %>", n) + } + + *p = int<%= dst_bit_size %>(n) + <% elsif dst_bit_size == pg_bit_size %> + *p = int<%= dst_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + <% else %> + *p = int<%= dst_bit_size %>(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) + <% end %> + + return nil +} + +type scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %> struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %>) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for uint<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*uint<%= dst_bit_size %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint<%= dst_bit_size %>", n) + } + <% if dst_bit_size < pg_bit_size %> + if n > math.MaxUint<%= dst_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for uint<%= dst_bit_size %>", n) + } + <% end %> + *p = uint<%= dst_bit_size %>(n) + + return nil +} +<% end %> + +<%# PostgreSQL binary format integer to Go machine integers %> +type scanPlanBinaryInt<%= pg_byte_size %>ToInt struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + <% if 32 < pg_bit_size %> + n := int64(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + if n < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", n) + } else if n > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", n) + } + + *p = int(n) + <% else %> + *p = int(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) + <% end %> + + return nil +} + +type scanPlanBinaryInt<%= pg_byte_size %>ToUint struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToUint) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for uint<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + <% if 32 < pg_bit_size %> + if uint64(n) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", n) + } + <% end %> + *p = uint(n) + + return nil +} + +<%# PostgreSQL binary format integer to Go Int64Scanner %> +type scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + + n := int64(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) + + return s.ScanInt64(Int8{Int64: n, Valid: true}) +} + +<%# PostgreSQL binary format integer to Go TextScanner %> +type scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + + n := int64(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} +<% end %> + +<%# Any text to all integer types %> +<% [ + ["8", 8], + ["16", 16], + ["32", 32], + ["64", 64], + ["", 0] +].each do |type_suffix, bit_size| %> +type scanPlanTextAnyToInt<%= type_suffix %> struct{} + +func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int<%= type_suffix %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, <%= bit_size %>) + if err != nil { + return err + } + + *p = int<%= type_suffix %>(n) + return nil +} + +type scanPlanTextAnyToUint<%= type_suffix %> struct{} + +func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint<%= type_suffix %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, <%= bit_size %>) + if err != nil { + return err + } + + *p = uint<%= type_suffix %>(n) + return nil +} +<% end %> + +type scanPlanTextAnyToInt64Scanner struct{} + +func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + err = s.ScanInt64(Int8{Int64: n, Valid: true}) + if err != nil { + return err + } + + return nil +} diff --git a/pgtype/int_test.go b/pgtype/int_test.go new file mode 100644 index 00000000..73294b3c --- /dev/null +++ b/pgtype/int_test.go @@ -0,0 +1,257 @@ +// Do not edit. Generated from pgtype/int_test.go.erb +package pgtype_test + +import ( + "context" + "math" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestInt2Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int2", []pgxtest.ValueRoundTripTest{ + {int8(1), new(int16), isExpectedEq(int16(1))}, + {int16(1), new(int16), isExpectedEq(int16(1))}, + {int32(1), new(int16), isExpectedEq(int16(1))}, + {int64(1), new(int16), isExpectedEq(int16(1))}, + {uint8(1), new(int16), isExpectedEq(int16(1))}, + {uint16(1), new(int16), isExpectedEq(int16(1))}, + {uint32(1), new(int16), isExpectedEq(int16(1))}, + {uint64(1), new(int16), isExpectedEq(int16(1))}, + {int(1), new(int16), isExpectedEq(int16(1))}, + {uint(1), new(int16), isExpectedEq(int16(1))}, + {pgtype.Int2{Int16: 1, Valid: true}, new(int16), isExpectedEq(int16(1))}, + {int32(-1), new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int16: -1, Valid: true})}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, + {math.MinInt16, new(int16), isExpectedEq(int16(math.MinInt16))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {0, new(int16), isExpectedEq(int16(0))}, + {1, new(int16), isExpectedEq(int16(1))}, + {math.MaxInt16, new(int16), isExpectedEq(int16(math.MaxInt16))}, + {1, new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int16: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, + {pgtype.Int2{}, new(pgtype.Int2), isExpectedEq(pgtype.Int2{})}, + {nil, new(*int16), isExpectedEq((*int16)(nil))}, + }) +} + +func TestInt2MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int2 + result string + }{ + {source: pgtype.Int2{Int16: 0}, result: "null"}, + {source: pgtype.Int2{Int16: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt2UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int2 + }{ + {source: "null", result: pgtype.Int2{Int16: 0}}, + {source: "1", result: pgtype.Int2{Int16: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int2 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt4Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4", []pgxtest.ValueRoundTripTest{ + {int8(1), new(int32), isExpectedEq(int32(1))}, + {int16(1), new(int32), isExpectedEq(int32(1))}, + {int32(1), new(int32), isExpectedEq(int32(1))}, + {int64(1), new(int32), isExpectedEq(int32(1))}, + {uint8(1), new(int32), isExpectedEq(int32(1))}, + {uint16(1), new(int32), isExpectedEq(int32(1))}, + {uint32(1), new(int32), isExpectedEq(int32(1))}, + {uint64(1), new(int32), isExpectedEq(int32(1))}, + {int(1), new(int32), isExpectedEq(int32(1))}, + {uint(1), new(int32), isExpectedEq(int32(1))}, + {pgtype.Int4{Int32: 1, Valid: true}, new(int32), isExpectedEq(int32(1))}, + {int32(-1), new(pgtype.Int4), isExpectedEq(pgtype.Int4{Int32: -1, Valid: true})}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, + {math.MinInt32, new(int32), isExpectedEq(int32(math.MinInt32))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {0, new(int32), isExpectedEq(int32(0))}, + {1, new(int32), isExpectedEq(int32(1))}, + {math.MaxInt32, new(int32), isExpectedEq(int32(math.MaxInt32))}, + {1, new(pgtype.Int4), isExpectedEq(pgtype.Int4{Int32: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, + {pgtype.Int4{}, new(pgtype.Int4), isExpectedEq(pgtype.Int4{})}, + {nil, new(*int32), isExpectedEq((*int32)(nil))}, + }) +} + +func TestInt4MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int4 + result string + }{ + {source: pgtype.Int4{Int32: 0}, result: "null"}, + {source: pgtype.Int4{Int32: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt4UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int4 + }{ + {source: "null", result: pgtype.Int4{Int32: 0}}, + {source: "1", result: pgtype.Int4{Int32: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt8Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ + {int8(1), new(int64), isExpectedEq(int64(1))}, + {int16(1), new(int64), isExpectedEq(int64(1))}, + {int32(1), new(int64), isExpectedEq(int64(1))}, + {int64(1), new(int64), isExpectedEq(int64(1))}, + {uint8(1), new(int64), isExpectedEq(int64(1))}, + {uint16(1), new(int64), isExpectedEq(int64(1))}, + {uint32(1), new(int64), isExpectedEq(int64(1))}, + {uint64(1), new(int64), isExpectedEq(int64(1))}, + {int(1), new(int64), isExpectedEq(int64(1))}, + {uint(1), new(int64), isExpectedEq(int64(1))}, + {pgtype.Int8{Int64: 1, Valid: true}, new(int64), isExpectedEq(int64(1))}, + {int32(-1), new(pgtype.Int8), isExpectedEq(pgtype.Int8{Int64: -1, Valid: true})}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, + {math.MinInt64, new(int64), isExpectedEq(int64(math.MinInt64))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {0, new(int64), isExpectedEq(int64(0))}, + {1, new(int64), isExpectedEq(int64(1))}, + {math.MaxInt64, new(int64), isExpectedEq(int64(math.MaxInt64))}, + {1, new(pgtype.Int8), isExpectedEq(pgtype.Int8{Int64: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, + {pgtype.Int8{}, new(pgtype.Int8), isExpectedEq(pgtype.Int8{})}, + {nil, new(*int64), isExpectedEq((*int64)(nil))}, + }) +} + +func TestInt8MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int8 + result string + }{ + {source: pgtype.Int8{Int64: 0}, result: "null"}, + {source: pgtype.Int8{Int64: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt8UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int8 + }{ + {source: "null", result: pgtype.Int8{Int64: 0}}, + {source: "1", result: pgtype.Int8{Int64: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb new file mode 100644 index 00000000..ac9a3f14 --- /dev/null +++ b/pgtype/int_test.go.erb @@ -0,0 +1,93 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/v5/pgtype" +) + +<% [2, 4, 8].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +func TestInt<%= pg_byte_size %>Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int<%= pg_byte_size %>", []pgxtest.ValueRoundTripTest{ + {int8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int64(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint64(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true}, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int32(-1), new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: -1, Valid: true})}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, + {math.MinInt<%= pg_bit_size %>, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(math.MinInt<%= pg_bit_size %>))}, + {-1, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(-1))}, + {0, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(0))}, + {1, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {math.MaxInt<%= pg_bit_size %>, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(math.MaxInt<%= pg_bit_size %>))}, + {1, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, + {pgtype.Int<%= pg_byte_size %>{}, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{})}, + {nil, new(*int<%= pg_bit_size %>), isExpectedEq((*int<%= pg_bit_size %>)(nil))}, + }) +} + +func TestInt<%= pg_byte_size %>MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int<%= pg_byte_size %> + result string + }{ + {source: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 0}, result: "null"}, + {source: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt<%= pg_byte_size %>UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int<%= pg_byte_size %> + }{ + {source: "null", result: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 0}}, + {source: "1", result: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int<%= pg_byte_size %> + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} +<% end %> diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go new file mode 100644 index 00000000..22ac3344 --- /dev/null +++ b/pgtype/integration_benchmark_test.go @@ -0,0 +1,1269 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 10) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 100) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 1000) n`, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 1000) n`, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb new file mode 100644 index 00000000..0175700a --- /dev/null +++ b/pgtype/integration_benchmark_test.go.erb @@ -0,0 +1,62 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5" +) + +<% + [ + ["int4", ["int16", "int32", "int64", "uint64", "pgtype.Int4"], [[1, 1], [1, 10], [10, 1], [100, 10]]], + ["numeric", ["int64", "float64", "pgtype.Numeric"], [[1, 1], [1, 10], [10, 1], [100, 10]]], + ].each do |pg_type, go_types, rows_columns| +%> +<% go_types.each do |go_type| %> +<% rows_columns.each do |rows, columns| %> +<% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |format_name, format_code| %> +func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go_type.gsub(/\W/, "_") %>_<%= rows %>_rows_<%= columns %>_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [<%= columns %>]<%= go_type %> + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, + []any{pgx.QueryResultFormats{<%= format_code %>}}, + ) + _, err := pgx.ForEachRow(rows, []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} +<% end %> +<% end %> +<% end %> +<% end %> + +<% [10, 100, 1000].each do |array_size| %> +<% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |format_name, format_code| %> +func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array_<%= array_size %>(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, <%= array_size %>) n`, + []any{pgx.QueryResultFormats{<%= format_code %>}}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} +<% end %> +<% end %> diff --git a/pgtype/integration_benchmark_test_gen.sh b/pgtype/integration_benchmark_test_gen.sh new file mode 100755 index 00000000..22ac01aa --- /dev/null +++ b/pgtype/integration_benchmark_test_gen.sh @@ -0,0 +1,2 @@ +erb integration_benchmark_test.go.erb > integration_benchmark_test.go +goimports -w integration_benchmark_test.go diff --git a/pgtype/interval.go b/pgtype/interval.go new file mode 100644 index 00000000..a172ecdb --- /dev/null +++ b/pgtype/interval.go @@ -0,0 +1,292 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +const ( + microsecondsPerSecond = 1000000 + microsecondsPerMinute = 60 * microsecondsPerSecond + microsecondsPerHour = 60 * microsecondsPerMinute + microsecondsPerDay = 24 * microsecondsPerHour + microsecondsPerMonth = 30 * microsecondsPerDay +) + +type IntervalScanner interface { + ScanInterval(v Interval) error +} + +type IntervalValuer interface { + IntervalValue() (Interval, error) +} + +type Interval struct { + Microseconds int64 + Days int32 + Months int32 + Valid bool +} + +func (interval *Interval) ScanInterval(v Interval) error { + *interval = v + return nil +} + +func (interval Interval) IntervalValue() (Interval, error) { + return interval, nil +} + +// Scan implements the database/sql Scanner interface. +func (interval *Interval) Scan(src any) error { + if src == nil { + *interval = Interval{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToIntervalScanner{}.Scan([]byte(src), interval) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (interval Interval) Value() (driver.Value, error) { + if !interval.Valid { + return nil, nil + } + + buf, err := IntervalCodec{}.PlanEncode(nil, 0, TextFormatCode, interval).Encode(interval, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type IntervalCodec struct{} + +func (IntervalCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (IntervalCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (IntervalCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(IntervalValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanIntervalCodecBinary{} + case TextFormatCode: + return encodePlanIntervalCodecText{} + } + + return nil +} + +type encodePlanIntervalCodecBinary struct{} + +func (encodePlanIntervalCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + interval, err := value.(IntervalValuer).IntervalValue() + if err != nil { + return nil, err + } + + if !interval.Valid { + return nil, nil + } + + buf = pgio.AppendInt64(buf, interval.Microseconds) + buf = pgio.AppendInt32(buf, interval.Days) + buf = pgio.AppendInt32(buf, interval.Months) + return buf, nil +} + +type encodePlanIntervalCodecText struct{} + +func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + interval, err := value.(IntervalValuer).IntervalValue() + if err != nil { + return nil, err + } + + if !interval.Valid { + return nil, nil + } + + if interval.Months != 0 { + buf = append(buf, strconv.FormatInt(int64(interval.Months), 10)...) + buf = append(buf, " mon "...) + } + + if interval.Days != 0 { + buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...) + buf = append(buf, " day "...) + } + + absMicroseconds := interval.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + buf = append(buf, '-') + } + + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond + microseconds := absMicroseconds % microsecondsPerSecond + + timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) + buf = append(buf, timeStr...) + return buf, nil +} + +func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case IntervalScanner: + return scanPlanBinaryIntervalToIntervalScanner{} + } + case TextFormatCode: + switch target.(type) { + case IntervalScanner: + return scanPlanTextAnyToIntervalScanner{} + } + } + + return nil +} + +type scanPlanBinaryIntervalToIntervalScanner struct{} + +func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst any) error { + scanner := (dst).(IntervalScanner) + + if src == nil { + return scanner.ScanInterval(Interval{}) + } + + if len(src) != 16 { + return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) + } + + microseconds := int64(binary.BigEndian.Uint64(src)) + days := int32(binary.BigEndian.Uint32(src[8:])) + months := int32(binary.BigEndian.Uint32(src[12:])) + + return scanner.ScanInterval(Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true}) +} + +type scanPlanTextAnyToIntervalScanner struct{} + +func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error { + scanner := (dst).(IntervalScanner) + + if src == nil { + return scanner.ScanInterval(Interval{}) + } + + var microseconds int64 + var days int32 + var months int32 + + parts := strings.Split(string(src), " ") + + for i := 0; i < len(parts)-1; i += 2 { + scalar, err := strconv.ParseInt(parts[i], 10, 64) + if err != nil { + return fmt.Errorf("bad interval format") + } + + switch parts[i+1] { + case "year", "years": + months += int32(scalar * 12) + case "mon", "mons": + months += int32(scalar) + case "day", "days": + days = int32(scalar) + } + } + + if len(parts)%2 == 1 { + timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) + if len(timeParts) != 3 { + return fmt.Errorf("bad interval format") + } + + var negative bool + if timeParts[0][0] == '-' { + negative = true + timeParts[0] = timeParts[0][1:] + } + + hours, err := strconv.ParseInt(timeParts[0], 10, 64) + if err != nil { + return fmt.Errorf("bad interval hour format: %s", timeParts[0]) + } + + minutes, err := strconv.ParseInt(timeParts[1], 10, 64) + if err != nil { + return fmt.Errorf("bad interval minute format: %s", timeParts[1]) + } + + secondParts := strings.SplitN(timeParts[2], ".", 2) + + seconds, err := strconv.ParseInt(secondParts[0], 10, 64) + if err != nil { + return fmt.Errorf("bad interval second format: %s", secondParts[0]) + } + + var uSeconds int64 + if len(secondParts) == 2 { + uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) + if err != nil { + return fmt.Errorf("bad interval decimal format: %s", secondParts[1]) + } + + for i := 0; i < 6-len(secondParts[1]); i++ { + uSeconds *= 10 + } + } + + microseconds = hours * microsecondsPerHour + microseconds += minutes * microsecondsPerMinute + microseconds += seconds * microsecondsPerSecond + microseconds += uSeconds + + if negative { + microseconds = -microseconds + } + } + + return scanner.ScanInterval(Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true}) +} + +func (c IntervalCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c IntervalCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var interval Interval + err := codecScan(c, m, oid, format, src, &interval) + if err != nil { + return nil, err + } + return interval, nil +} diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go new file mode 100644 index 00000000..754c44e3 --- /dev/null +++ b/pgtype/interval_test.go @@ -0,0 +1,138 @@ +package pgtype_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestIntervalCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "interval", []pgxtest.ValueRoundTripTest{ + { + pgtype.Interval{Microseconds: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 1000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000000, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 123202800000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 123202800000000, Valid: true}), + }, + { + pgtype.Interval{Days: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: 1, Valid: true}), + }, + { + pgtype.Interval{Months: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 1, Valid: true}), + }, + { + pgtype.Interval{Months: 12, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 12, Valid: true}), + }, + { + pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1000000, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -123202800000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -123202800000000, Valid: true}), + }, + { + pgtype.Interval{Days: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: -1, Valid: true}), + }, + { + pgtype.Interval{Months: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -1, Valid: true}), + }, + { + pgtype.Interval{Months: -12, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -12, Valid: true}), + }, + { + pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}), + }, + { + "1 second", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000000, Valid: true}), + }, + { + "1.000001 second", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000001, Valid: true}), + }, + { + "34223 hours", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 123202800000000, Valid: true}), + }, + { + "1 day", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: 1, Valid: true}), + }, + { + "1 month", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 1, Valid: true}), + }, + { + "1 year", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 12, Valid: true}), + }, + { + "-13 mon", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -13, Valid: true}), + }, + {time.Hour, new(time.Duration), isExpectedEq(time.Hour)}, + { + pgtype.Interval{Months: 1, Days: 1, Valid: true}, + new(time.Duration), + isExpectedEq(time.Duration(2678400000000000)), + }, + {pgtype.Interval{}, new(pgtype.Interval), isExpectedEq(pgtype.Interval{})}, + {nil, new(pgtype.Interval), isExpectedEq(pgtype.Interval{})}, + }) +} diff --git a/pgtype/json.go b/pgtype/json.go new file mode 100644 index 00000000..d0d98fc9 --- /dev/null +++ b/pgtype/json.go @@ -0,0 +1,161 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "reflect" +) + +type JSONCodec struct{} + +func (JSONCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (JSONCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch value.(type) { + case string: + return encodePlanJSONCodecEitherFormatString{} + case []byte: + return encodePlanJSONCodecEitherFormatByteSlice{} + } + + // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the + // appropriate wrappers here. + for _, f := range []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + } { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + return encodePlanJSONCodecEitherFormatMarshal{} +} + +type encodePlanJSONCodecEitherFormatString struct{} + +func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonString := value.(string) + buf = append(buf, jsonString...) + return buf, nil +} + +type encodePlanJSONCodecEitherFormatByteSlice struct{} + +func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonBytes := value.([]byte) + if jsonBytes == nil { + return nil, nil + } + + buf = append(buf, jsonBytes...) + return buf, nil +} + +type encodePlanJSONCodecEitherFormatMarshal struct{} + +func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonBytes, err := json.Marshal(value) + if err != nil { + return nil, err + } + + buf = append(buf, jsonBytes...) + return buf, nil +} + +func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch target.(type) { + case *string: + return scanPlanAnyToString{} + case *[]byte: + return scanPlanJSONToByteSlice{} + case BytesScanner: + return scanPlanBinaryBytesToBytesScanner{} + default: + return scanPlanJSONToJSONUnmarshal{} + } + +} + +type scanPlanAnyToString struct{} + +func (scanPlanAnyToString) Scan(src []byte, dst any) error { + p := dst.(*string) + *p = string(src) + return nil +} + +type scanPlanJSONToByteSlice struct{} + +func (scanPlanJSONToByteSlice) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanJSONToBytesScanner struct{} + +func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BytesScanner) + return scanner.ScanBytes(src) +} + +type scanPlanJSONToJSONUnmarshal struct{} + +func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { + if src == nil { + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() == reflect.Ptr { + el := dstValue.Elem() + switch el.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + elem := reflect.ValueOf(dst).Elem() + elem.Set(reflect.Zero(elem.Type())) + + return json.Unmarshal(src, dst) +} + +func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil +} + +func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var dst any + err := json.Unmarshal(src, &dst) + return dst, err +} diff --git a/pgtype/json_test.go b/pgtype/json_test.go new file mode 100644 index 00000000..db20e576 --- /dev/null +++ b/pgtype/json_test.go @@ -0,0 +1,108 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +func isExpectedEqMap(a any) func(any) bool { + return func(v any) bool { + aa := a.(map[string]any) + bb := v.(map[string]any) + + if (aa == nil) != (bb == nil) { + return false + } + + if aa == nil { + return true + } + + if len(aa) != len(bb) { + return false + } + + for k := range aa { + if aa[k] != bb[k] { + return false + } + } + + return true + } +} + +func TestJSONCodec(t *testing.T) { + type jsonStruct struct { + Name string `json:"name"` + Age int `json:"age"` + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{ + {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{ + {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, + {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, + {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, + {[]byte(`"hello"`), new([]byte), isExpectedEqBytes([]byte(`"hello"`))}, + {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, + {map[string]any{"foo": "bar"}, new(map[string]any), isExpectedEqMap(map[string]any{"foo": "bar"})}, + {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, + }) +} + +// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648 +func TestJSONCodecUnmarshalSQLNull(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + // Slices are nilified + slice := []string{"foo", "bar", "baz"} + err := conn.QueryRow(ctx, "select null::json").Scan(&slice) + require.NoError(t, err) + require.Nil(t, slice) + + // Maps are nilified + m := map[string]any{"foo": "bar"} + err = conn.QueryRow(ctx, "select null::json").Scan(&m) + require.NoError(t, err) + require.Nil(t, m) + + // Pointer to pointer are nilified + n := 42 + p := &n + err = conn.QueryRow(ctx, "select null::json").Scan(&p) + require.NoError(t, err) + require.Nil(t, p) + + // A string cannot scan a NULL. + str := "foobar" + err = conn.QueryRow(ctx, "select null::json").Scan(&str) + require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string") + + // A non-string cannot scan a NULL. + err = conn.QueryRow(ctx, "select null::json").Scan(&n) + require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *int") + }) +} + +func TestJSONCodecClearExistingValueBeforeUnmarshal(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + m := map[string]any{} + err := conn.QueryRow(ctx, `select '{"foo": "bar"}'::json`).Scan(&m) + require.NoError(t, err) + require.Equal(t, map[string]any{"foo": "bar"}, m) + + err = conn.QueryRow(ctx, `select '{"baz": "quz"}'::json`).Scan(&m) + require.NoError(t, err) + require.Equal(t, map[string]any{"baz": "quz"}, m) + }) +} diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go new file mode 100644 index 00000000..25555e7f --- /dev/null +++ b/pgtype/jsonb.go @@ -0,0 +1,127 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + "fmt" +) + +type JSONBCodec struct{} + +func (JSONBCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (JSONBCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value) + if plan != nil { + return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan} + } + case TextFormatCode: + return JSONCodec{}.PlanEncode(m, oid, format, value) + } + + return nil +} + +type encodePlanJSONBCodecBinaryWrapper struct { + textPlan EncodePlan +} + +func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (newBuf []byte, err error) { + buf = append(buf, 1) + return plan.textPlan.Encode(value, buf) +} + +func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target) + if plan != nil { + return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan} + } + case TextFormatCode: + return JSONCodec{}.PlanScan(m, oid, format, target) + } + + return nil +} + +type scanPlanJSONBCodecBinaryUnwrapper struct { + textPlan ScanPlan +} + +func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error { + if src == nil { + return plan.textPlan.Scan(src, dst) + } + + if len(src) == 0 { + return fmt.Errorf("jsonb too short") + } + + if src[0] != 1 { + return fmt.Errorf("unknown jsonb version number %d", src[0]) + } + + return plan.textPlan.Scan(src[1:], dst) +} + +func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case BinaryFormatCode: + if len(src) == 0 { + return nil, fmt.Errorf("jsonb too short") + } + + if src[0] != 1 { + return nil, fmt.Errorf("unknown jsonb version number %d", src[0]) + } + + dstBuf := make([]byte, len(src)-1) + copy(dstBuf, src[1:]) + return dstBuf, nil + case TextFormatCode: + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + switch format { + case BinaryFormatCode: + if len(src) == 0 { + return nil, fmt.Errorf("jsonb too short") + } + + if src[0] != 1 { + return nil, fmt.Errorf("unknown jsonb version number %d", src[0]) + } + + src = src[1:] + case TextFormatCode: + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } + + var dst any + err := json.Unmarshal(src, &dst) + return dst, err +} diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go new file mode 100644 index 00000000..7dadc6c5 --- /dev/null +++ b/pgtype/jsonb_test.go @@ -0,0 +1,33 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestJSONBTranscode(t *testing.T) { + type jsonStruct struct { + Name string `json:"name"` + Age int `json:"age"` + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "jsonb", []pgxtest.ValueRoundTripTest{ + {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{ + {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, + {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, + {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, + {[]byte(`"hello"`), new([]byte), isExpectedEqBytes([]byte(`"hello"`))}, + {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, + {map[string]any{"foo": "bar"}, new(map[string]any), isExpectedEqMap(map[string]any{"foo": "bar"})}, + {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, + }) +} diff --git a/pgtype/line.go b/pgtype/line.go new file mode 100644 index 00000000..4ae8003e --- /dev/null +++ b/pgtype/line.go @@ -0,0 +1,225 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type LineScanner interface { + ScanLine(v Line) error +} + +type LineValuer interface { + LineValue() (Line, error) +} + +type Line struct { + A, B, C float64 + Valid bool +} + +func (line *Line) ScanLine(v Line) error { + *line = v + return nil +} + +func (line Line) LineValue() (Line, error) { + return line, nil +} + +func (line *Line) Set(src any) error { + return fmt.Errorf("cannot convert %v to Line", src) +} + +// Scan implements the database/sql Scanner interface. +func (line *Line) Scan(src any) error { + if src == nil { + *line = Line{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToLineScanner{}.Scan([]byte(src), line) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (line Line) Value() (driver.Value, error) { + if !line.Valid { + return nil, nil + } + + buf, err := LineCodec{}.PlanEncode(nil, 0, TextFormatCode, line).Encode(line, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type LineCodec struct{} + +func (LineCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (LineCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (LineCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(LineValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanLineCodecBinary{} + case TextFormatCode: + return encodePlanLineCodecText{} + } + + return nil +} + +type encodePlanLineCodecBinary struct{} + +func (encodePlanLineCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + line, err := value.(LineValuer).LineValue() + if err != nil { + return nil, err + } + + if !line.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(line.A)) + buf = pgio.AppendUint64(buf, math.Float64bits(line.B)) + buf = pgio.AppendUint64(buf, math.Float64bits(line.C)) + return buf, nil +} + +type encodePlanLineCodecText struct{} + +func (encodePlanLineCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + line, err := value.(LineValuer).LineValue() + if err != nil { + return nil, err + } + + if !line.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`{%s,%s,%s}`, + strconv.FormatFloat(line.A, 'f', -1, 64), + strconv.FormatFloat(line.B, 'f', -1, 64), + strconv.FormatFloat(line.C, 'f', -1, 64), + )...) + return buf, nil +} + +func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case LineScanner: + return scanPlanBinaryLineToLineScanner{} + } + case TextFormatCode: + switch target.(type) { + case LineScanner: + return scanPlanTextAnyToLineScanner{} + } + } + + return nil +} + +type scanPlanBinaryLineToLineScanner struct{} + +func (scanPlanBinaryLineToLineScanner) Scan(src []byte, dst any) error { + scanner := (dst).(LineScanner) + + if src == nil { + return scanner.ScanLine(Line{}) + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for line: %v", len(src)) + } + + a := binary.BigEndian.Uint64(src) + b := binary.BigEndian.Uint64(src[8:]) + c := binary.BigEndian.Uint64(src[16:]) + + return scanner.ScanLine(Line{ + A: math.Float64frombits(a), + B: math.Float64frombits(b), + C: math.Float64frombits(c), + Valid: true, + }) +} + +type scanPlanTextAnyToLineScanner struct{} + +func (scanPlanTextAnyToLineScanner) Scan(src []byte, dst any) error { + scanner := (dst).(LineScanner) + + if src == nil { + return scanner.ScanLine(Line{}) + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for line: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) + if len(parts) < 3 { + return fmt.Errorf("invalid format for line") + } + + a, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + b, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + c, err := strconv.ParseFloat(parts[2], 64) + if err != nil { + return err + } + + return scanner.ScanLine(Line{A: a, B: b, C: c, Valid: true}) +} + +func (c LineCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c LineCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var line Line + err := codecScan(c, m, oid, format, src, &line) + if err != nil { + return nil, err + } + return line, nil +} diff --git a/pgtype/line_test.go b/pgtype/line_test.go new file mode 100644 index 00000000..dc980ce1 --- /dev/null +++ b/pgtype/line_test.go @@ -0,0 +1,58 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestLineTranscode(t *testing.T) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support type line") + + if _, ok := conn.TypeMap().TypeForName("line"); !ok { + t.Skip("Skipping due to no line type") + } + + // line may exist but not be usable on 9.3 :( + var isPG93 bool + err := conn.QueryRow(context.Background(), "select version() ~ '9.3'").Scan(&isPG93) + if err != nil { + t.Fatal(err) + } + if isPG93 { + t.Skip("Skipping due to unimplemented line type in PG 9.3") + } + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "line", []pgxtest.ValueRoundTripTest{ + { + pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89012345, + Valid: true, + }, + new(pgtype.Line), + isExpectedEq(pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89012345, + Valid: true, + }), + }, + { + pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Valid: true, + }, + new(pgtype.Line), + isExpectedEq(pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Valid: true, + }), + }, + {pgtype.Line{}, new(pgtype.Line), isExpectedEq(pgtype.Line{})}, + {nil, new(pgtype.Line), isExpectedEq(pgtype.Line{})}, + }) +} diff --git a/pgtype/lseg.go b/pgtype/lseg.go new file mode 100644 index 00000000..97f130dc --- /dev/null +++ b/pgtype/lseg.go @@ -0,0 +1,238 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type LsegScanner interface { + ScanLseg(v Lseg) error +} + +type LsegValuer interface { + LsegValue() (Lseg, error) +} + +type Lseg struct { + P [2]Vec2 + Valid bool +} + +func (lseg *Lseg) ScanLseg(v Lseg) error { + *lseg = v + return nil +} + +func (lseg Lseg) LsegValue() (Lseg, error) { + return lseg, nil +} + +// Scan implements the database/sql Scanner interface. +func (lseg *Lseg) Scan(src any) error { + if src == nil { + *lseg = Lseg{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToLsegScanner{}.Scan([]byte(src), lseg) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (lseg Lseg) Value() (driver.Value, error) { + if !lseg.Valid { + return nil, nil + } + + buf, err := LsegCodec{}.PlanEncode(nil, 0, TextFormatCode, lseg).Encode(lseg, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type LsegCodec struct{} + +func (LsegCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (LsegCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (LsegCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(LsegValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanLsegCodecBinary{} + case TextFormatCode: + return encodePlanLsegCodecText{} + } + + return nil +} + +type encodePlanLsegCodecBinary struct{} + +func (encodePlanLsegCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + lseg, err := value.(LsegValuer).LsegValue() + if err != nil { + return nil, err + } + + if !lseg.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[1].Y)) + return buf, nil +} + +type encodePlanLsegCodecText struct{} + +func (encodePlanLsegCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + lseg, err := value.(LsegValuer).LsegValue() + if err != nil { + return nil, err + } + + if !lseg.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(lseg.P[0].X, 'f', -1, 64), + strconv.FormatFloat(lseg.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(lseg.P[1].X, 'f', -1, 64), + strconv.FormatFloat(lseg.P[1].Y, 'f', -1, 64), + )...) + return buf, nil +} + +func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case LsegScanner: + return scanPlanBinaryLsegToLsegScanner{} + } + case TextFormatCode: + switch target.(type) { + case LsegScanner: + return scanPlanTextAnyToLsegScanner{} + } + } + + return nil +} + +type scanPlanBinaryLsegToLsegScanner struct{} + +func (scanPlanBinaryLsegToLsegScanner) Scan(src []byte, dst any) error { + scanner := (dst).(LsegScanner) + + if src == nil { + return scanner.ScanLseg(Lseg{}) + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for lseg: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + return scanner.ScanLseg(Lseg{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Valid: true, + }) +} + +type scanPlanTextAnyToLsegScanner struct{} + +func (scanPlanTextAnyToLsegScanner) Scan(src []byte, dst any) error { + scanner := (dst).(LsegScanner) + + if src == nil { + return scanner.ScanLseg(Lseg{}) + } + + if len(src) < 11 { + return fmt.Errorf("invalid length for lseg: %v", len(src)) + } + + str := string(src[2:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-2] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + return scanner.ScanLseg(Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true}) +} + +func (c LsegCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c LsegCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var lseg Lseg + err := codecScan(c, m, oid, format, src, &lseg) + if err != nil { + return nil, err + } + return lseg, nil +} diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go new file mode 100644 index 00000000..04fde0eb --- /dev/null +++ b/pgtype/lseg_test.go @@ -0,0 +1,40 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestLsegTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support type lseg") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "lseg", []pgxtest.ValueRoundTripTest{ + { + pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, + Valid: true, + }, + new(pgtype.Lseg), + isExpectedEq(pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, + Valid: true, + }), + }, + { + pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, + }, + new(pgtype.Lseg), + isExpectedEq(pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, + }), + }, + {pgtype.Lseg{}, new(pgtype.Lseg), isExpectedEq(pgtype.Lseg{})}, + {nil, new(pgtype.Lseg), isExpectedEq(pgtype.Lseg{})}, + }) +} diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go new file mode 100644 index 00000000..e913ec90 --- /dev/null +++ b/pgtype/macaddr.go @@ -0,0 +1,162 @@ +package pgtype + +import ( + "database/sql/driver" + "net" +) + +type MacaddrCodec struct{} + +func (MacaddrCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (MacaddrCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (MacaddrCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case net.HardwareAddr: + return encodePlanMacaddrCodecBinaryHardwareAddr{} + case TextValuer: + return encodePlanMacAddrCodecTextValuer{} + + } + case TextFormatCode: + switch value.(type) { + case net.HardwareAddr: + return encodePlanMacaddrCodecTextHardwareAddr{} + case TextValuer: + return encodePlanTextCodecTextValuer{} + } + } + + return nil +} + +type encodePlanMacaddrCodecBinaryHardwareAddr struct{} + +func (encodePlanMacaddrCodecBinaryHardwareAddr) Encode(value any, buf []byte) (newBuf []byte, err error) { + addr := value.(net.HardwareAddr) + if addr == nil { + return nil, nil + } + + return append(buf, addr...), nil +} + +type encodePlanMacAddrCodecTextValuer struct{} + +func (encodePlanMacAddrCodecTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + t, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + if !t.Valid { + return nil, nil + } + + addr, err := net.ParseMAC(t.String) + if err != nil { + return nil, err + } + + return append(buf, addr...), nil +} + +type encodePlanMacaddrCodecTextHardwareAddr struct{} + +func (encodePlanMacaddrCodecTextHardwareAddr) Encode(value any, buf []byte) (newBuf []byte, err error) { + addr := value.(net.HardwareAddr) + if addr == nil { + return nil, nil + } + + return append(buf, addr.String()...), nil +} + +func (MacaddrCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *net.HardwareAddr: + return scanPlanBinaryMacaddrToHardwareAddr{} + case TextScanner: + return scanPlanBinaryMacaddrToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *net.HardwareAddr: + return scanPlanTextMacaddrToHardwareAddr{} + case TextScanner: + return scanPlanTextAnyToTextScanner{} + } + } + + return nil +} + +type scanPlanBinaryMacaddrToHardwareAddr struct{} + +func (scanPlanBinaryMacaddrToHardwareAddr) Scan(src []byte, dst any) error { + dstBuf := dst.(*net.HardwareAddr) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanBinaryMacaddrToTextScanner struct{} + +func (scanPlanBinaryMacaddrToTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + if src == nil { + return scanner.ScanText(Text{}) + } + + return scanner.ScanText(Text{String: net.HardwareAddr(src).String(), Valid: true}) +} + +type scanPlanTextMacaddrToHardwareAddr struct{} + +func (scanPlanTextMacaddrToHardwareAddr) Scan(src []byte, dst any) error { + p := dst.(*net.HardwareAddr) + + if src == nil { + *p = nil + return nil + } + + addr, err := net.ParseMAC(string(src)) + if err != nil { + return err + } + + *p = addr + + return nil +} + +func (c MacaddrCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c MacaddrCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var addr net.HardwareAddr + err := codecScan(c, m, oid, format, src, &addr) + if err != nil { + return nil, err + } + return addr, nil +} diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go new file mode 100644 index 00000000..ef6dae00 --- /dev/null +++ b/pgtype/macaddr_test.go @@ -0,0 +1,51 @@ +package pgtype_test + +import ( + "bytes" + "context" + "net" + "testing" + + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqHardwareAddr(a any) func(any) bool { + return func(v any) bool { + aa := a.(net.HardwareAddr) + vv := v.(net.HardwareAddr) + + if (aa == nil) != (vv == nil) { + return false + } + + if aa == nil { + return true + } + + return bytes.Compare(aa, vv) == 0 + } +} + +func TestMacaddrCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support type macaddr") + + // Only testing known OID query exec modes as net.HardwareAddr could map to macaddr or macaddr8. + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "macaddr", []pgxtest.ValueRoundTripTest{ + { + mustParseMacaddr(t, "01:23:45:67:89:ab"), + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab")), + }, + { + "01:23:45:67:89:ab", + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab")), + }, + { + mustParseMacaddr(t, "01:23:45:67:89:ab"), + new(string), + isExpectedEq("01:23:45:67:89:ab"), + }, + {nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))}, + }) +} diff --git a/pgtype/multirange.go b/pgtype/multirange.go new file mode 100644 index 00000000..34950b34 --- /dev/null +++ b/pgtype/multirange.go @@ -0,0 +1,443 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// MultirangeGetter is a type that can be converted into a PostgreSQL multirange. +type MultirangeGetter interface { + // IsNull returns true if the value is SQL NULL. + IsNull() bool + + // Len returns the number of elements in the multirange. + Len() int + + // Index returns the element at i. + Index(i int) any + + // IndexType returns a non-nil scan target of the type Index will return. This is used by MultirangeCodec.PlanEncode. + IndexType() any +} + +// MultirangeSetter is a type can be set from a PostgreSQL multirange. +type MultirangeSetter interface { + // ScanNull sets the value to SQL NULL. + ScanNull() error + + // SetLen prepares the value such that ScanIndex can be called for each element. This will remove any existing + // elements. + SetLen(n int) error + + // ScanIndex returns a value usable as a scan target for i. SetLen must be called before ScanIndex. + ScanIndex(i int) any + + // ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by + // MultirangeCodec.PlanScan. + ScanIndexType() any +} + +// MultirangeCodec is a codec for any multirange type. +type MultirangeCodec struct { + ElementType *Type +} + +func (c *MultirangeCodec) FormatSupported(format int16) bool { + return c.ElementType.Codec.FormatSupported(format) +} + +func (c *MultirangeCodec) PreferredFormat() int16 { + return c.ElementType.Codec.PreferredFormat() +} + +func (c *MultirangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + multirangeValuer, ok := value.(MultirangeGetter) + if !ok { + return nil + } + + elementType := multirangeValuer.IndexType() + + elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType) + if elementEncodePlan == nil { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanMultirangeCodecBinary{ac: c, m: m, oid: oid} + case TextFormatCode: + return &encodePlanMultirangeCodecText{ac: c, m: m, oid: oid} + } + + return nil +} + +type encodePlanMultirangeCodecText struct { + ac *MultirangeCodec + m *Map + oid uint32 +} + +func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + multirange := value.(MultirangeGetter) + + if multirange.IsNull() { + return nil, nil + } + + elementCount := multirange.Len() + + buf = append(buf, '{') + + var encodePlan EncodePlan + var lastElemType reflect.Type + inElemBuf := make([]byte, 0, 32) + for i := 0; i < elementCount; i++ { + if i > 0 { + buf = append(buf, ',') + } + + elem := multirange.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", multirange.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, inElemBuf) + if err != nil { + return nil, err + } + } + + if elemBuf == nil { + return nil, fmt.Errorf("multirange cannot contain NULL element") + } else { + buf = append(buf, elemBuf...) + } + } + + buf = append(buf, '}') + + return buf, nil +} + +type encodePlanMultirangeCodecBinary struct { + ac *MultirangeCodec + m *Map + oid uint32 +} + +func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + multirange := value.(MultirangeGetter) + + if multirange.IsNull() { + return nil, nil + } + + elementCount := multirange.Len() + + buf = pgio.AppendInt32(buf, int32(elementCount)) + + var encodePlan EncodePlan + var lastElemType reflect.Type + for i := 0; i < elementCount; i++ { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elem := multirange.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", multirange.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, buf) + if err != nil { + return nil, err + } + } + + if elemBuf == nil { + return nil, fmt.Errorf("multirange cannot contain NULL element") + } else { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +func (c *MultirangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + multirangeScanner, ok := target.(MultirangeSetter) + if !ok { + return nil + } + + elementType := multirangeScanner.ScanIndexType() + + elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType) + if _, ok := elementScanPlan.(*scanPlanFail); ok { + return nil + } + + return &scanPlanMultirangeCodec{ + multirangeCodec: c, + m: m, + oid: oid, + formatCode: format, + } +} + +func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error { + rp := 0 + + elementCount := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + err := multirange.SetLen(elementCount) + if err != nil { + return err + } + + if elementCount == 0 { + return nil + } + + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0)) + if elementScanPlan == nil { + elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0)) + } + + for i := 0; i < elementCount; i++ { + elem := multirange.ScanIndex(i) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elementScanPlan.Scan(elemSrc, elem) + if err != nil { + return fmt.Errorf("failed to scan multirange element %d: %w", i, err) + } + } + + return nil +} + +func (c *MultirangeCodec) decodeText(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error { + elements, err := parseUntypedTextMultirange(src) + if err != nil { + return err + } + + err = multirange.SetLen(len(elements)) + if err != nil { + return err + } + + if len(elements) == 0 { + return nil + } + + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0)) + if elementScanPlan == nil { + elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0)) + } + + for i, s := range elements { + elem := multirange.ScanIndex(i) + err = elementScanPlan.Scan([]byte(s), elem) + if err != nil { + return err + } + } + + return nil +} + +type scanPlanMultirangeCodec struct { + multirangeCodec *MultirangeCodec + m *Map + oid uint32 + formatCode int16 + elementScanPlan ScanPlan +} + +func (spac *scanPlanMultirangeCodec) Scan(src []byte, dst any) error { + c := spac.multirangeCodec + m := spac.m + oid := spac.oid + formatCode := spac.formatCode + + multirange := dst.(MultirangeSetter) + + if src == nil { + return multirange.ScanNull() + } + + switch formatCode { + case BinaryFormatCode: + return c.decodeBinary(m, oid, src, multirange) + case TextFormatCode: + return c.decodeText(m, oid, src, multirange) + default: + return fmt.Errorf("unknown format code %d", formatCode) + } +} + +func (c *MultirangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (c *MultirangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var multirange Multirange[Range[any]] + err := m.PlanScan(oid, format, &multirange).Scan(src, &multirange) + return multirange, err +} + +func parseUntypedTextMultirange(src []byte) ([]string, error) { + elements := make([]string, 0) + + buf := bytes.NewBuffer(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != '{' { + return nil, fmt.Errorf("invalid multirange, expected '{': %v", err) + } + +parseValueLoop: + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid multirange: %v", err) + } + + switch r { + case ',': // skip range separator + case '}': + break parseValueLoop + default: + buf.UnreadRune() + value, err := parseRange(buf) + if err != nil { + return nil, fmt.Errorf("invalid multirange value: %v", err) + } + elements = append(elements, value) + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + return elements, nil + +} + +func parseRange(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + boundSepRead := false + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case ',', '}': + if r == ',' && !boundSepRead { + boundSepRead = true + break + } + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +// Multirange is a generic multirange type. +// +// T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to +// enforce the RangeScanner constraint. +type Multirange[T RangeValuer] []T + +func (r Multirange[T]) IsNull() bool { + return r == nil +} + +func (r Multirange[T]) Len() int { + return len(r) +} + +func (r Multirange[T]) Index(i int) any { + return r[i] +} + +func (r Multirange[T]) IndexType() any { + var zero T + return zero +} + +func (r *Multirange[T]) ScanNull() error { + *r = nil + return nil +} + +func (r *Multirange[T]) SetLen(n int) error { + *r = make([]T, n) + return nil +} + +func (r Multirange[T]) ScanIndex(i int) any { + return &r[i] +} + +func (r Multirange[T]) ScanIndexType() any { + return new(T) +} diff --git a/pgtype/multirange_test.go b/pgtype/multirange_test.go new file mode 100644 index 00000000..77273e59 --- /dev/null +++ b/pgtype/multirange_test.go @@ -0,0 +1,114 @@ +package pgtype_test + +import ( + "context" + "reflect" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +func TestMultirangeCodecTranscode(t *testing.T) { + skipPostgreSQLVersionLessThan(t, 14) + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4multirange", []pgxtest.ValueRoundTripTest{ + { + pgtype.Multirange[pgtype.Range[pgtype.Int4]](nil), + new(pgtype.Multirange[pgtype.Range[pgtype.Int4]]), + func(a any) bool { return reflect.DeepEqual(pgtype.Multirange[pgtype.Range[pgtype.Int4]](nil), a) }, + }, + { + pgtype.Multirange[pgtype.Range[pgtype.Int4]]{}, + new(pgtype.Multirange[pgtype.Range[pgtype.Int4]]), + func(a any) bool { return reflect.DeepEqual(pgtype.Multirange[pgtype.Range[pgtype.Int4]]{}, a) }, + }, + { + pgtype.Multirange[pgtype.Range[pgtype.Int4]]{ + { + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + { + Lower: pgtype.Int4{Int32: 7, Valid: true}, + Upper: pgtype.Int4{Int32: 9, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + new(pgtype.Multirange[pgtype.Range[pgtype.Int4]]), + func(a any) bool { + return reflect.DeepEqual(pgtype.Multirange[pgtype.Range[pgtype.Int4]]{ + { + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + { + Lower: pgtype.Int4{Int32: 7, Valid: true}, + Upper: pgtype.Int4{Int32: 9, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, a) + }, + }, + }) +} + +func TestMultirangeCodecDecodeValue(t *testing.T) { + skipPostgreSQLVersionLessThan(t, 14) + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select int4multirange(int4range(1, 5), int4range(7,9))`, + expected: pgtype.Multirange[pgtype.Range[any]]{ + { + Lower: int32(1), + Upper: int32(5), + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + { + Lower: int32(7), + Upper: int32(9), + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} diff --git a/pgtype/numeric.go b/pgtype/numeric.go new file mode 100644 index 00000000..a5f4ed3a --- /dev/null +++ b/pgtype/numeric.go @@ -0,0 +1,801 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "math/big" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 +const nbase = 10000 + +const ( + pgNumericNaN = 0x00000000c0000000 + pgNumericNaNSign = 0xc000 + + pgNumericPosInf = 0x00000000d0000000 + pgNumericPosInfSign = 0xd000 + + pgNumericNegInf = 0x00000000f0000000 + pgNumericNegInfSign = 0xf000 +) + +var big0 *big.Int = big.NewInt(0) +var big1 *big.Int = big.NewInt(1) +var big10 *big.Int = big.NewInt(10) +var big100 *big.Int = big.NewInt(100) +var big1000 *big.Int = big.NewInt(1000) + +var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) +var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) +var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) +var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) +var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) +var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) +var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) +var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) +var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) +var bigMinInt *big.Int = big.NewInt(int64(minInt)) + +var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) +var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) +var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) +var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) +var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) + +var bigNBase *big.Int = big.NewInt(nbase) +var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) +var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) +var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) + +type NumericScanner interface { + ScanNumeric(v Numeric) error +} + +type NumericValuer interface { + NumericValue() (Numeric, error) +} + +type Numeric struct { + Int *big.Int + Exp int32 + NaN bool + InfinityModifier InfinityModifier + Valid bool +} + +func (n *Numeric) ScanNumeric(v Numeric) error { + *n = v + return nil +} + +func (n Numeric) NumericValue() (Numeric, error) { + return n, nil +} + +func (n Numeric) Float64Value() (Float8, error) { + if !n.Valid { + return Float8{}, nil + } else if n.NaN { + return Float8{Float64: math.NaN(), Valid: true}, nil + } else if n.InfinityModifier == Infinity { + return Float8{Float64: math.Inf(1), Valid: true}, nil + } else if n.InfinityModifier == NegativeInfinity { + return Float8{Float64: math.Inf(-1), Valid: true}, nil + } + + buf := make([]byte, 0, 32) + + if n.Int == nil { + buf = append(buf, '0') + } else { + buf = append(buf, n.Int.String()...) + } + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) + + f, err := strconv.ParseFloat(string(buf), 64) + if err != nil { + return Float8{}, err + } + + return Float8{Float64: f, Valid: true}, nil +} + +func (n *Numeric) ScanInt64(v Int8) error { + if !v.Valid { + *n = Numeric{} + return nil + } + + *n = Numeric{Int: big.NewInt(v.Int64), Valid: true} + return nil +} + +func (n Numeric) Int64Value() (Int8, error) { + if !n.Valid { + return Int8{}, nil + } + + bi, err := n.toBigInt() + if err != nil { + return Int8{}, err + } + + if !bi.IsInt64() { + return Int8{}, fmt.Errorf("cannot convert %v to int64", n) + } + + return Int8{Int64: bi.Int64(), Valid: true}, nil +} + +func (n *Numeric) toBigInt() (*big.Int, error) { + if n.Exp == 0 { + return n.Int, nil + } + + num := &big.Int{} + num.Set(n.Int) + if n.Exp > 0 { + mul := &big.Int{} + mul.Exp(big10, big.NewInt(int64(n.Exp)), nil) + num.Mul(num, mul) + return num, nil + } + + div := &big.Int{} + div.Exp(big10, big.NewInt(int64(-n.Exp)), nil) + remainder := &big.Int{} + num.DivMod(num, div, remainder) + if remainder.Cmp(big0) != 0 { + return nil, fmt.Errorf("cannot convert %v to integer", n) + } + return num, nil +} + +func parseNumericString(str string) (n *big.Int, exp int32, err error) { + parts := strings.SplitN(str, ".", 2) + digits := strings.Join(parts, "") + + if len(parts) > 1 { + exp = int32(-len(parts[1])) + } else { + for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' { + digits = digits[:len(digits)-1] + exp++ + } + } + + accum := &big.Int{} + if _, ok := accum.SetString(digits, 10); !ok { + return nil, 0, fmt.Errorf("%s is not a number", str) + } + + return accum, exp, nil +} + +func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { + digits := len(src) / 2 + if digits > 4 { + digits = 4 + } + + rp := 0 + + for i := 0; i < digits; i++ { + if i > 0 { + accum *= nbase + } + accum += int64(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return accum, rp, digits +} + +// Scan implements the database/sql Scanner interface. +func (n *Numeric) Scan(src any) error { + if src == nil { + *n = Numeric{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToNumericScanner{}.Scan([]byte(src), n) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (n Numeric) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + + buf, err := NumericCodec{}.PlanEncode(nil, 0, TextFormatCode, n).Encode(n, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +func (n Numeric) MarshalJSON() ([]byte, error) { + if !n.Valid { + return []byte("null"), nil + } + + if n.NaN { + return []byte(`"NaN"`), nil + } + + return n.numberTextBytes(), nil +} + +// numberString returns a string of the number. undefined if NaN, infinite, or NULL +func (n Numeric) numberTextBytes() []byte { + intStr := n.Int.String() + buf := &bytes.Buffer{} + exp := int(n.Exp) + if exp > 0 { + buf.WriteString(intStr) + for i := 0; i < exp; i++ { + buf.WriteByte('0') + } + } else if exp < 0 { + if len(intStr) <= -exp { + buf.WriteString("0.") + leadingZeros := -exp - len(intStr) + for i := 0; i < leadingZeros; i++ { + buf.WriteByte('0') + } + buf.WriteString(intStr) + } else if len(intStr) > -exp { + dpPos := len(intStr) + exp + buf.WriteString(intStr[:dpPos]) + buf.WriteByte('.') + buf.WriteString(intStr[dpPos:]) + } + } else { + buf.WriteString(intStr) + } + + return buf.Bytes() +} + +type NumericCodec struct{} + +func (NumericCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (NumericCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (NumericCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case NumericValuer: + return encodePlanNumericCodecBinaryNumericValuer{} + case Float64Valuer: + return encodePlanNumericCodecBinaryFloat64Valuer{} + case Int64Valuer: + return encodePlanNumericCodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case NumericValuer: + return encodePlanNumericCodecTextNumericValuer{} + case Float64Valuer: + return encodePlanNumericCodecTextFloat64Valuer{} + case Int64Valuer: + return encodePlanNumericCodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanNumericCodecBinaryNumericValuer struct{} + +func (encodePlanNumericCodecBinaryNumericValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(NumericValuer).NumericValue() + if err != nil { + return nil, err + } + + return encodeNumericBinary(n, buf) +} + +type encodePlanNumericCodecBinaryFloat64Valuer struct{} + +func (encodePlanNumericCodecBinaryFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if math.IsNaN(n.Float64) { + return encodeNumericBinary(Numeric{NaN: true, Valid: true}, buf) + } else if math.IsInf(n.Float64, 1) { + return encodeNumericBinary(Numeric{InfinityModifier: Infinity, Valid: true}, buf) + } else if math.IsInf(n.Float64, -1) { + return encodeNumericBinary(Numeric{InfinityModifier: NegativeInfinity, Valid: true}, buf) + } + num, exp, err := parseNumericString(strconv.FormatFloat(n.Float64, 'f', -1, 64)) + if err != nil { + return nil, err + } + + return encodeNumericBinary(Numeric{Int: num, Exp: exp, Valid: true}, buf) +} + +type encodePlanNumericCodecBinaryInt64Valuer struct{} + +func (encodePlanNumericCodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return encodeNumericBinary(Numeric{Int: big.NewInt(n.Int64), Valid: true}, buf) +} + +func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { + if !n.Valid { + return nil, nil + } + + if n.NaN { + buf = pgio.AppendUint64(buf, pgNumericNaN) + return buf, nil + } else if n.InfinityModifier == Infinity { + buf = pgio.AppendUint64(buf, pgNumericPosInf) + return buf, nil + } else if n.InfinityModifier == NegativeInfinity { + buf = pgio.AppendUint64(buf, pgNumericNegInf) + return buf, nil + } + + var sign int16 + if n.Int.Cmp(big0) < 0 { + sign = 16384 + } + + absInt := &big.Int{} + wholePart := &big.Int{} + fracPart := &big.Int{} + remainder := &big.Int{} + absInt.Abs(n.Int) + + // Normalize absInt and exp to where exp is always a multiple of 4. This makes + // converting to 16-bit base 10,000 digits easier. + var exp int32 + switch n.Exp % 4 { + case 1, -3: + exp = n.Exp - 1 + absInt.Mul(absInt, big10) + case 2, -2: + exp = n.Exp - 2 + absInt.Mul(absInt, big100) + case 3, -1: + exp = n.Exp - 3 + absInt.Mul(absInt, big1000) + default: + exp = n.Exp + } + + if exp < 0 { + divisor := &big.Int{} + divisor.Exp(big10, big.NewInt(int64(-exp)), nil) + wholePart.DivMod(absInt, divisor, fracPart) + fracPart.Add(fracPart, divisor) + } else { + wholePart = absInt + } + + var wholeDigits, fracDigits []int16 + + for wholePart.Cmp(big0) != 0 { + wholePart.DivMod(wholePart, bigNBase, remainder) + wholeDigits = append(wholeDigits, int16(remainder.Int64())) + } + + if fracPart.Cmp(big0) != 0 { + for fracPart.Cmp(big1) != 0 { + fracPart.DivMod(fracPart, bigNBase, remainder) + fracDigits = append(fracDigits, int16(remainder.Int64())) + } + } + + buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) + + var weight int16 + if len(wholeDigits) > 0 { + weight = int16(len(wholeDigits) - 1) + if exp > 0 { + weight += int16(exp / 4) + } + } else { + weight = int16(exp/4) - 1 + int16(len(fracDigits)) + } + buf = pgio.AppendInt16(buf, weight) + + buf = pgio.AppendInt16(buf, sign) + + var dscale int16 + if n.Exp < 0 { + dscale = int16(-n.Exp) + } + buf = pgio.AppendInt16(buf, dscale) + + for i := len(wholeDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, wholeDigits[i]) + } + + for i := len(fracDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, fracDigits[i]) + } + + return buf, nil +} + +type encodePlanNumericCodecTextNumericValuer struct{} + +func (encodePlanNumericCodecTextNumericValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(NumericValuer).NumericValue() + if err != nil { + return nil, err + } + + return encodeNumericText(n, buf) +} + +type encodePlanNumericCodecTextFloat64Valuer struct{} + +func (encodePlanNumericCodecTextFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if math.IsNaN(n.Float64) { + buf = append(buf, "NaN"...) + } else if math.IsInf(n.Float64, 1) { + buf = append(buf, "Infinity"...) + } else if math.IsInf(n.Float64, -1) { + buf = append(buf, "-Infinity"...) + } else { + buf = append(buf, strconv.FormatFloat(n.Float64, 'f', -1, 64)...) + } + return buf, nil +} + +type encodePlanNumericCodecTextInt64Valuer struct{} + +func (encodePlanNumericCodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + buf = append(buf, strconv.FormatInt(n.Int64, 10)...) + return buf, nil +} + +func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { + if !n.Valid { + return nil, nil + } + + if n.NaN { + buf = append(buf, "NaN"...) + return buf, nil + } else if n.InfinityModifier == Infinity { + buf = append(buf, "Infinity"...) + return buf, nil + } else if n.InfinityModifier == NegativeInfinity { + buf = append(buf, "-Infinity"...) + return buf, nil + } + + buf = append(buf, n.numberTextBytes()...) + + return buf, nil +} + +func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case NumericScanner: + return scanPlanBinaryNumericToNumericScanner{} + case Float64Scanner: + return scanPlanBinaryNumericToFloat64Scanner{} + case Int64Scanner: + return scanPlanBinaryNumericToInt64Scanner{} + case TextScanner: + return scanPlanBinaryNumericToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case NumericScanner: + return scanPlanTextAnyToNumericScanner{} + case Float64Scanner: + return scanPlanTextAnyToFloat64Scanner{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +type scanPlanBinaryNumericToNumericScanner struct{} + +func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst any) error { + scanner := (dst).(NumericScanner) + + if src == nil { + return scanner.ScanNumeric(Numeric{}) + } + + if len(src) < 8 { + return fmt.Errorf("numeric incomplete %v", src) + } + + rp := 0 + ndigits := binary.BigEndian.Uint16(src[rp:]) + rp += 2 + weight := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + sign := binary.BigEndian.Uint16(src[rp:]) + rp += 2 + dscale := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if sign == pgNumericNaNSign { + return scanner.ScanNumeric(Numeric{NaN: true, Valid: true}) + } else if sign == pgNumericPosInfSign { + return scanner.ScanNumeric(Numeric{InfinityModifier: Infinity, Valid: true}) + } else if sign == pgNumericNegInfSign { + return scanner.ScanNumeric(Numeric{InfinityModifier: NegativeInfinity, Valid: true}) + } + + if ndigits == 0 { + return scanner.ScanNumeric(Numeric{Int: big.NewInt(0), Valid: true}) + } + + if len(src[rp:]) < int(ndigits)*2 { + return fmt.Errorf("numeric incomplete %v", src) + } + + accum := &big.Int{} + + for i := 0; i < int(ndigits+3)/4; i++ { + int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:]) + rp += bytesRead + + if i > 0 { + var mul *big.Int + switch digitsRead { + case 1: + mul = bigNBase + case 2: + mul = bigNBaseX2 + case 3: + mul = bigNBaseX3 + case 4: + mul = bigNBaseX4 + default: + return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) + } + accum.Mul(accum, mul) + } + + accum.Add(accum, big.NewInt(int64accum)) + } + + exp := (int32(weight) - int32(ndigits) + 1) * 4 + + if dscale > 0 { + fracNBaseDigits := int16(int32(ndigits) - int32(weight) - 1) + fracDecimalDigits := fracNBaseDigits * 4 + + if dscale > fracDecimalDigits { + multCount := int(dscale - fracDecimalDigits) + for i := 0; i < multCount; i++ { + accum.Mul(accum, big10) + exp-- + } + } else if dscale < fracDecimalDigits { + divCount := int(fracDecimalDigits - dscale) + for i := 0; i < divCount; i++ { + accum.Div(accum, big10) + exp++ + } + } + } + + reduced := &big.Int{} + remainder := &big.Int{} + if exp >= 0 { + for { + reduced.DivMod(accum, big10, remainder) + if remainder.Cmp(big0) != 0 { + break + } + accum.Set(reduced) + exp++ + } + } + + if sign != 0 { + accum.Neg(accum) + } + + return scanner.ScanNumeric(Numeric{Int: accum, Exp: exp, Valid: true}) +} + +type scanPlanBinaryNumericToFloat64Scanner struct{} + +func (scanPlanBinaryNumericToFloat64Scanner) Scan(src []byte, dst any) error { + scanner := (dst).(Float64Scanner) + + if src == nil { + return scanner.ScanFloat64(Float8{}) + } + + var n Numeric + + err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n) + if err != nil { + return err + } + + f8, err := n.Float64Value() + if err != nil { + return err + } + + return scanner.ScanFloat64(f8) +} + +type scanPlanBinaryNumericToInt64Scanner struct{} + +func (scanPlanBinaryNumericToInt64Scanner) Scan(src []byte, dst any) error { + scanner := (dst).(Int64Scanner) + + if src == nil { + return scanner.ScanInt64(Int8{}) + } + + var n Numeric + + err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n) + if err != nil { + return err + } + + bigInt, err := n.toBigInt() + if err != nil { + return err + } + + if !bigInt.IsInt64() { + return fmt.Errorf("%v is out of range for int64", bigInt) + } + + return scanner.ScanInt64(Int8{Int64: bigInt.Int64(), Valid: true}) +} + +type scanPlanBinaryNumericToTextScanner struct{} + +func (scanPlanBinaryNumericToTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + var n Numeric + + err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n) + if err != nil { + return err + } + + sbuf, err := encodeNumericText(n, nil) + if err != nil { + return err + } + + return scanner.ScanText(Text{String: string(sbuf), Valid: true}) +} + +type scanPlanTextAnyToNumericScanner struct{} + +func (scanPlanTextAnyToNumericScanner) Scan(src []byte, dst any) error { + scanner := (dst).(NumericScanner) + + if src == nil { + return scanner.ScanNumeric(Numeric{}) + } + + if string(src) == "NaN" { + return scanner.ScanNumeric(Numeric{NaN: true, Valid: true}) + } else if string(src) == "Infinity" { + return scanner.ScanNumeric(Numeric{InfinityModifier: Infinity, Valid: true}) + } else if string(src) == "-Infinity" { + return scanner.ScanNumeric(Numeric{InfinityModifier: NegativeInfinity, Valid: true}) + } + + num, exp, err := parseNumericString(string(src)) + if err != nil { + return err + } + + return scanner.ScanNumeric(Numeric{Int: num, Exp: exp, Valid: true}) +} + +func (c NumericCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + if format == TextFormatCode { + return string(src), nil + } + + var n Numeric + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + + buf, err := m.Encode(oid, TextFormatCode, n, nil) + if err != nil { + return nil, err + } + return string(buf), nil +} + +func (c NumericCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n Numeric + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go new file mode 100644 index 00000000..071f0c24 --- /dev/null +++ b/pgtype/numeric_test.go @@ -0,0 +1,224 @@ +package pgtype_test + +import ( + "context" + "encoding/json" + "math" + "math/big" + "math/rand" + "strconv" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mustParseBigInt(t *testing.T, src string) *big.Int { + i := &big.Int{} + if _, ok := i.SetString(src, 10); !ok { + t.Fatalf("could not parse big.Int: %s", src) + } + return i +} + +func isExpectedEqNumeric(a any) func(any) bool { + return func(v any) bool { + aa := a.(pgtype.Numeric) + vv := v.(pgtype.Numeric) + + if aa.Valid != vv.Valid { + return false + } + + // If NULL doesn't matter what the rest of the values are. + if !aa.Valid { + return true + } + + if !(aa.NaN == vv.NaN && aa.InfinityModifier == vv.InfinityModifier) { + return false + } + + // If NaN or InfinityModifier are set then Int and Exp don't matter. + if aa.NaN || aa.InfinityModifier != pgtype.Finite { + return true + } + + aaInt := (&big.Int{}).Set(aa.Int) + vvInt := (&big.Int{}).Set(vv.Int) + + if aa.Exp < vv.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(vv.Exp-aa.Exp)), nil) + vvInt.Mul(vvInt, mul) + } else if aa.Exp > vv.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(aa.Exp-vv.Exp)), nil) + aaInt.Mul(aaInt, mul) + } + + return aaInt.Cmp(vvInt) == 0 + } +} + +func mustParseNumeric(t *testing.T, src string) pgtype.Numeric { + var n pgtype.Numeric + plan := pgtype.NumericCodec{}.PlanScan(nil, pgtype.NumericOID, pgtype.TextFormatCode, &n) + require.NotNil(t, plan) + err := plan.Scan([]byte(src), &n) + require.NoError(t, err) + return n +} + +func TestNumericCodec(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") + + max := new(big.Int).Exp(big.NewInt(10), big.NewInt(147454), nil) + max.Add(max, big.NewInt(1)) + longestNumeric := pgtype.Numeric{Int: max, Exp: -16383, Valid: true} + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", []pgxtest.ValueRoundTripTest{ + {mustParseNumeric(t, "1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, + {mustParseNumeric(t, "3.14159"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "3.14159"))}, + {mustParseNumeric(t, "100010001"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "100010001"))}, + {mustParseNumeric(t, "100010001.0001"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "100010001.0001"))}, + {mustParseNumeric(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"))}, + {mustParseNumeric(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"))}, + {mustParseNumeric(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"))}, + {pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 81, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 81, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 82, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 82, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 83, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 83, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 84, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 84, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -91, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -91, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true})}, + {pgtype.Numeric{NaN: true, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{NaN: true, Valid: true})}, + {longestNumeric, new(pgtype.Numeric), isExpectedEqNumeric(longestNumeric)}, + {mustParseNumeric(t, "1"), new(int64), isExpectedEq(int64(1))}, + {math.NaN(), new(float64), func(a any) bool { return math.IsNaN(a.(float64)) }}, + {float32(math.NaN()), new(float32), func(a any) bool { return math.IsNaN(float64(a.(float32))) }}, + {int64(-1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "-1"))}, + {int64(0), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0"))}, + {int64(1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, + {int64(math.MinInt64), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MinInt64, 10)))}, + {int64(math.MinInt64 + 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MinInt64+1, 10)))}, + {int64(math.MaxInt64), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64, 10)))}, + {int64(math.MaxInt64 - 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64-1, 10)))}, + {"1.23", new(string), isExpectedEq("1.23")}, + {pgtype.Numeric{}, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, + {nil, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, + {mustParseNumeric(t, "1"), new(string), isExpectedEq("1")}, + {pgtype.Numeric{NaN: true, Valid: true}, new(string), isExpectedEq("NaN")}, + }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ + {mustParseNumeric(t, "-1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "-1"))}, + {mustParseNumeric(t, "0"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0"))}, + {mustParseNumeric(t, "1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, + }) +} + +func TestNumericCodecInfinity(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") + skipPostgreSQLVersionLessThan(t, 14) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", []pgxtest.ValueRoundTripTest{ + {math.Inf(1), new(float64), isExpectedEq(math.Inf(1))}, + {float32(math.Inf(1)), new(float32), isExpectedEq(float32(math.Inf(1)))}, + {math.Inf(-1), new(float64), isExpectedEq(math.Inf(-1))}, + {float32(math.Inf(-1)), new(float32), isExpectedEq(float32(math.Inf(-1)))}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(string), isExpectedEq("Infinity")}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(string), isExpectedEq("-Infinity")}, + }) +} + +func TestNumericFloat64Valuer(t *testing.T) { + for i, tt := range []struct { + n pgtype.Numeric + f pgtype.Float8 + }{ + {mustParseNumeric(t, "1"), pgtype.Float8{Float64: 1, Valid: true}}, + {mustParseNumeric(t, "0.0000000000000000001"), pgtype.Float8{Float64: 0.0000000000000000001, Valid: true}}, + {mustParseNumeric(t, "-99999999999"), pgtype.Float8{Float64: -99999999999, Valid: true}}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, pgtype.Float8{Float64: math.Inf(1), Valid: true}}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, pgtype.Float8{Float64: math.Inf(-1), Valid: true}}, + {pgtype.Numeric{Valid: true}, pgtype.Float8{Valid: true}}, + {pgtype.Numeric{}, pgtype.Float8{}}, + } { + f, err := tt.n.Float64Value() + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.f, f, "%d", i) + } + + f, err := pgtype.Numeric{NaN: true, Valid: true}.Float64Value() + assert.NoError(t, err) + assert.True(t, math.IsNaN(f.Float64)) + assert.True(t, f.Valid) +} + +func TestNumericCodecFuzz(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") + + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + tests := make([]pgxtest.ValueRoundTripTest, 0, 2000) + for i := 0; i < 10; i++ { + for j := -50; j < 50; j++ { + num := (&big.Int{}).Rand(r, max) + + n := pgtype.Numeric{Int: num, Exp: int32(j), Valid: true} + tests = append(tests, pgxtest.ValueRoundTripTest{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) + + negNum := &big.Int{} + negNum.Neg(num) + n = pgtype.Numeric{Int: negNum, Exp: int32(j), Valid: true} + tests = append(tests, pgxtest.ValueRoundTripTest{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) + } + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", tests) +} + +func TestNumericMarshalJSON(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + + for i, tt := range []struct { + decString string + }{ + {"NaN"}, + {"0"}, + {"1"}, + {"-1"}, + {"1000000000000000000"}, + {"1234.56789"}, + {"1.56789"}, + {"0.00000000000056789"}, + {"0.00123000"}, + {"123e-3"}, + {"243723409723490243842378942378901237502734019231380123e23790"}, + {"3409823409243892349028349023482934092340892390101e-14021"}, + } { + var num pgtype.Numeric + var pgJSON string + err := conn.QueryRow(ctx, `select $1::numeric, to_json($1::numeric)`, tt.decString).Scan(&num, &pgJSON) + require.NoErrorf(t, err, "%d", i) + + goJSON, err := json.Marshal(num) + require.NoErrorf(t, err, "%d", i) + + require.Equal(t, pgJSON, string(goJSON)) + } + }) +} diff --git a/pgtype/path.go b/pgtype/path.go new file mode 100644 index 00000000..73e0ec52 --- /dev/null +++ b/pgtype/path.go @@ -0,0 +1,272 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type PathScanner interface { + ScanPath(v Path) error +} + +type PathValuer interface { + PathValue() (Path, error) +} + +type Path struct { + P []Vec2 + Closed bool + Valid bool +} + +func (path *Path) ScanPath(v Path) error { + *path = v + return nil +} + +func (path Path) PathValue() (Path, error) { + return path, nil +} + +// Scan implements the database/sql Scanner interface. +func (path *Path) Scan(src any) error { + if src == nil { + *path = Path{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToPathScanner{}.Scan([]byte(src), path) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (path Path) Value() (driver.Value, error) { + if !path.Valid { + return nil, nil + } + + buf, err := PathCodec{}.PlanEncode(nil, 0, TextFormatCode, path).Encode(path, nil) + if err != nil { + return nil, err + } + + return string(buf), err +} + +type PathCodec struct{} + +func (PathCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (PathCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (PathCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(PathValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanPathCodecBinary{} + case TextFormatCode: + return encodePlanPathCodecText{} + } + + return nil +} + +type encodePlanPathCodecBinary struct{} + +func (encodePlanPathCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + path, err := value.(PathValuer).PathValue() + if err != nil { + return nil, err + } + + if !path.Valid { + return nil, nil + } + + var closeByte byte + if path.Closed { + closeByte = 1 + } + buf = append(buf, closeByte) + + buf = pgio.AppendInt32(buf, int32(len(path.P))) + + for _, p := range path.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +type encodePlanPathCodecText struct{} + +func (encodePlanPathCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + path, err := value.(PathValuer).PathValue() + if err != nil { + return nil, err + } + + if !path.Valid { + return nil, nil + } + + var startByte, endByte byte + if path.Closed { + startByte = '(' + endByte = ')' + } else { + startByte = '[' + endByte = ']' + } + buf = append(buf, startByte) + + for i, p := range path.P { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) + } + + buf = append(buf, endByte) + + return buf, nil +} + +func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case PathScanner: + return scanPlanBinaryPathToPathScanner{} + } + case TextFormatCode: + switch target.(type) { + case PathScanner: + return scanPlanTextAnyToPathScanner{} + } + } + + return nil +} + +type scanPlanBinaryPathToPathScanner struct{} + +func (scanPlanBinaryPathToPathScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PathScanner) + + if src == nil { + return scanner.ScanPath(Path{}) + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == 1 + pointCount := int(binary.BigEndian.Uint32(src[1:])) + + rp := 5 + + if 5+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + return scanner.ScanPath(Path{ + P: points, + Closed: closed, + Valid: true, + }) +} + +type scanPlanTextAnyToPathScanner struct{} + +func (scanPlanTextAnyToPathScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PathScanner) + + if src == nil { + return scanner.ScanPath(Path{}) + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == '(' + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } + } + + return scanner.ScanPath(Path{P: points, Closed: closed, Valid: true}) +} + +func (c PathCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c PathCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var path Path + err := codecScan(c, m, oid, format, src, &path) + if err != nil { + return nil, err + } + return path, nil +} diff --git a/pgtype/path_test.go b/pgtype/path_test.go new file mode 100644 index 00000000..cfffd22a --- /dev/null +++ b/pgtype/path_test.go @@ -0,0 +1,76 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqPath(a any) func(any) bool { + return func(v any) bool { + ap := a.(pgtype.Path) + vp := v.(pgtype.Path) + + if !(ap.Valid == vp.Valid && ap.Closed == vp.Closed && len(ap.P) == len(vp.P)) { + return false + } + + for i := range ap.P { + if ap.P[i] != vp.P[i] { + return false + } + } + + return true + } +} + +func TestPathTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support type path") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "path", []pgxtest.ValueRoundTripTest{ + { + pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, + Closed: false, + Valid: true, + }, + new(pgtype.Path), + isExpectedEqPath(pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, + Closed: false, + Valid: true, + }), + }, + { + pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, + Closed: true, + Valid: true, + }, + new(pgtype.Path), + isExpectedEqPath(pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, + Closed: true, + Valid: true, + }), + }, + { + pgtype.Path{ + P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Closed: true, + Valid: true, + }, + new(pgtype.Path), + isExpectedEqPath(pgtype.Path{ + P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Closed: true, + Valid: true, + }), + }, + {pgtype.Path{}, new(pgtype.Path), isExpectedEqPath(pgtype.Path{})}, + {nil, new(pgtype.Path), isExpectedEqPath(pgtype.Path{})}, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go new file mode 100644 index 00000000..f8ad2bf3 --- /dev/null +++ b/pgtype/pgtype.go @@ -0,0 +1,1934 @@ +package pgtype + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "net" + "net/netip" + "reflect" + "time" +) + +// PostgreSQL oids for common types +const ( + BoolOID = 16 + ByteaOID = 17 + QCharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TIDOID = 27 + XIDOID = 28 + CIDOID = 29 + JSONOID = 114 + JSONArrayOID = 199 + PointOID = 600 + LsegOID = 601 + PathOID = 602 + BoxOID = 603 + PolygonOID = 604 + LineOID = 628 + LineArrayOID = 629 + CIDROID = 650 + CIDRArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + CircleOID = 718 + CircleArrayOID = 719 + UnknownOID = 705 + MacaddrOID = 829 + InetOID = 869 + BoolArrayOID = 1000 + QCharArrayOID = 1003 + NameArrayOID = 1003 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + TIDArrayOID = 1010 + ByteaArrayOID = 1001 + XIDArrayOID = 1011 + CIDArrayOID = 1012 + BPCharArrayOID = 1014 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + PointArrayOID = 1017 + LsegArrayOID = 1018 + PathArrayOID = 1019 + BoxArrayOID = 1020 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + PolygonArrayOID = 1027 + OIDArrayOID = 1028 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 + MacaddrArrayOID = 1040 + InetArrayOID = 1041 + BPCharOID = 1042 + VarcharOID = 1043 + DateOID = 1082 + TimeOID = 1083 + TimestampOID = 1114 + TimestampArrayOID = 1115 + DateArrayOID = 1182 + TimeArrayOID = 1183 + TimestamptzOID = 1184 + TimestamptzArrayOID = 1185 + IntervalOID = 1186 + IntervalArrayOID = 1187 + NumericArrayOID = 1231 + BitOID = 1560 + BitArrayOID = 1561 + VarbitOID = 1562 + VarbitArrayOID = 1563 + NumericOID = 1700 + RecordOID = 2249 + RecordArrayOID = 2287 + UUIDOID = 2950 + UUIDArrayOID = 2951 + JSONBOID = 3802 + JSONBArrayOID = 3807 + DaterangeOID = 3912 + DaterangeArrayOID = 3913 + Int4rangeOID = 3904 + Int4rangeArrayOID = 3905 + NumrangeOID = 3906 + NumrangeArrayOID = 3907 + TsrangeOID = 3908 + TsrangeArrayOID = 3909 + TstzrangeOID = 3910 + TstzrangeArrayOID = 3911 + Int8rangeOID = 3926 + Int8rangeArrayOID = 3927 + Int4multirangeOID = 4451 + NummultirangeOID = 4532 + TsmultirangeOID = 4533 + TstzmultirangeOID = 4534 + DatemultirangeOID = 4535 + Int8multirangeOID = 4536 + Int4multirangeArrayOID = 6150 + NummultirangeArrayOID = 6151 + TsmultirangeArrayOID = 6152 + TstzmultirangeArrayOID = 6153 + DatemultirangeArrayOID = 6155 + Int8multirangeArrayOID = 6157 +) + +type InfinityModifier int8 + +const ( + Infinity InfinityModifier = 1 + Finite InfinityModifier = 0 + NegativeInfinity InfinityModifier = -Infinity +) + +func (im InfinityModifier) String() string { + switch im { + case Finite: + return "finite" + case Infinity: + return "infinity" + case NegativeInfinity: + return "-infinity" + default: + return "invalid" + } +} + +// PostgreSQL format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + +// A Codec converts between Go and PostgreSQL values. +type Codec interface { + // FormatSupported returns true if the format is supported. + FormatSupported(int16) bool + + // PreferredFormat returns the preferred format. + PreferredFormat() int16 + + // PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be + // found then nil is returned. + PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan + + // PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If + // no plan can be found then nil is returned. + PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan + + // DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface. + DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) + + // DecodeValue returns src decoded into its default format. + DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) +} + +type nullAssignmentError struct { + dst any +} + +func (e *nullAssignmentError) Error() string { + return fmt.Sprintf("cannot assign NULL to %T", e.dst) +} + +type Type struct { + Codec Codec + Name string + OID uint32 +} + +// Map is the mapping between PostgreSQL server types and Go type handling logic. It can encode values for +// transmission to a PostgreSQL server and scan received values. +type Map struct { + oidToType map[uint32]*Type + nameToType map[string]*Type + reflectTypeToName map[reflect.Type]string + oidToFormatCode map[uint32]int16 + + reflectTypeToType map[reflect.Type]*Type + + memoizedScanPlans map[uint32]map[reflect.Type][2]ScanPlan + + // TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every + // time a wrapper is found the PlanEncode method will be recursively called with the new value. This allows several layers of wrappers + // to be built up. There are default functions placed in this slice by NewMap(). In most cases these functions + // should run last. i.e. Additional functions should typically be prepended not appended. + TryWrapEncodePlanFuncs []TryWrapEncodePlanFunc + + // TryWrapScanPlanFuncs is a slice of functions that will wrap a target that cannot be scanned into by the Codec. Every + // time a wrapper is found the PlanScan method will be recursively called with the new target. This allows several layers of wrappers + // to be built up. There are default functions placed in this slice by NewMap(). In most cases these functions + // should run last. i.e. Additional functions should typically be prepended not appended. + TryWrapScanPlanFuncs []TryWrapScanPlanFunc +} + +func NewMap() *Map { + m := &Map{ + oidToType: make(map[uint32]*Type), + nameToType: make(map[string]*Type), + reflectTypeToName: make(map[reflect.Type]string), + oidToFormatCode: make(map[uint32]int16), + + memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), + + TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapBuiltinTypeEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + TryWrapStructEncodePlan, + TryWrapSliceEncodePlan, + TryWrapMultiDimSliceEncodePlan, + }, + + TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ + TryPointerPointerScanPlan, + TryWrapBuiltinTypeScanPlan, + TryFindUnderlyingTypeScanPlan, + TryWrapStructScanPlan, + TryWrapPtrSliceScanPlan, + TryWrapPtrMultiDimSliceScanPlan, + }, + } + + // Base types + m.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + m.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) + m.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) + m.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) + m.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) + m.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) + m.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) + m.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) + m.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) + m.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) + m.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) + m.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) + m.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) + m.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) + m.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) + m.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) + m.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) + m.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) + m.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) + m.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) + m.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) + m.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) + m.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) + m.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) + m.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) + m.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) + m.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) + m.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) + m.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) + m.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) + m.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) + m.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) + m.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) + m.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) + m.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) + m.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) + m.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) + m.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) + m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) + m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + + // Range types + m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}}) + m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}}) + m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}}) + m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}}) + m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}}) + m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}}) + + // Multirange types + m.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[DaterangeOID]}}) + m.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int4rangeOID]}}) + m.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int8rangeOID]}}) + m.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[NumrangeOID]}}) + m.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TsrangeOID]}}) + m.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TstzrangeOID]}}) + + // Array types + m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}}) + m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}}) + m.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoolOID]}}) + m.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoxOID]}}) + m.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BPCharOID]}}) + m.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ByteaOID]}}) + m.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[QCharOID]}}) + m.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDOID]}}) + m.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDROID]}}) + m.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CircleOID]}}) + m.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DateOID]}}) + m.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DaterangeOID]}}) + m.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float4OID]}}) + m.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float8OID]}}) + m.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[InetOID]}}) + m.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int2OID]}}) + m.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4OID]}}) + m.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4rangeOID]}}) + m.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8OID]}}) + m.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8rangeOID]}}) + m.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[IntervalOID]}}) + m.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONOID]}}) + m.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONBOID]}}) + m.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LineOID]}}) + m.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LsegOID]}}) + m.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[MacaddrOID]}}) + m.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NameOID]}}) + m.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumericOID]}}) + m.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumrangeOID]}}) + m.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[OIDOID]}}) + m.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PathOID]}}) + m.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PointOID]}}) + m.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PolygonOID]}}) + m.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[RecordOID]}}) + m.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TextOID]}}) + m.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TIDOID]}}) + m.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimeOID]}}) + m.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestampOID]}}) + m.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestamptzOID]}}) + m.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TsrangeOID]}}) + m.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TstzrangeOID]}}) + m.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[UUIDOID]}}) + m.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarbitOID]}}) + m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}}) + m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}}) + + // Integer types that directly map to a PostgreSQL type + registerDefaultPgTypeVariants[int16](m, "int2") + registerDefaultPgTypeVariants[int32](m, "int4") + registerDefaultPgTypeVariants[int64](m, "int8") + + // Integer types that do not have a direct match to a PostgreSQL type + registerDefaultPgTypeVariants[int8](m, "int8") + registerDefaultPgTypeVariants[int](m, "int8") + registerDefaultPgTypeVariants[uint8](m, "int8") + registerDefaultPgTypeVariants[uint16](m, "int8") + registerDefaultPgTypeVariants[uint32](m, "int8") + registerDefaultPgTypeVariants[uint64](m, "int8") + registerDefaultPgTypeVariants[uint](m, "int8") + + registerDefaultPgTypeVariants[float32](m, "float4") + registerDefaultPgTypeVariants[float64](m, "float8") + + registerDefaultPgTypeVariants[bool](m, "bool") + registerDefaultPgTypeVariants[time.Time](m, "timestamptz") + registerDefaultPgTypeVariants[time.Duration](m, "interval") + registerDefaultPgTypeVariants[string](m, "text") + registerDefaultPgTypeVariants[[]byte](m, "bytea") + + registerDefaultPgTypeVariants[net.IP](m, "inet") + registerDefaultPgTypeVariants[net.IPNet](m, "cidr") + registerDefaultPgTypeVariants[netip.Addr](m, "inet") + registerDefaultPgTypeVariants[netip.Prefix](m, "cidr") + + // pgtype provided structs + registerDefaultPgTypeVariants[Bits](m, "varbit") + registerDefaultPgTypeVariants[Bool](m, "bool") + registerDefaultPgTypeVariants[Box](m, "box") + registerDefaultPgTypeVariants[Circle](m, "circle") + registerDefaultPgTypeVariants[Date](m, "date") + registerDefaultPgTypeVariants[Range[Date]](m, "daterange") + registerDefaultPgTypeVariants[Multirange[Range[Date]]](m, "datemultirange") + registerDefaultPgTypeVariants[Float4](m, "float4") + registerDefaultPgTypeVariants[Float8](m, "float8") + registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. + registerDefaultPgTypeVariants[Multirange[Range[Float8]]](m, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. + registerDefaultPgTypeVariants[Int2](m, "int2") + registerDefaultPgTypeVariants[Int4](m, "int4") + registerDefaultPgTypeVariants[Range[Int4]](m, "int4range") + registerDefaultPgTypeVariants[Multirange[Range[Int4]]](m, "int4multirange") + registerDefaultPgTypeVariants[Int8](m, "int8") + registerDefaultPgTypeVariants[Range[Int8]](m, "int8range") + registerDefaultPgTypeVariants[Multirange[Range[Int8]]](m, "int8multirange") + registerDefaultPgTypeVariants[Interval](m, "interval") + registerDefaultPgTypeVariants[Line](m, "line") + registerDefaultPgTypeVariants[Lseg](m, "lseg") + registerDefaultPgTypeVariants[Numeric](m, "numeric") + registerDefaultPgTypeVariants[Range[Numeric]](m, "numrange") + registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](m, "nummultirange") + registerDefaultPgTypeVariants[Path](m, "path") + registerDefaultPgTypeVariants[Point](m, "point") + registerDefaultPgTypeVariants[Polygon](m, "polygon") + registerDefaultPgTypeVariants[TID](m, "tid") + registerDefaultPgTypeVariants[Text](m, "text") + registerDefaultPgTypeVariants[Time](m, "time") + registerDefaultPgTypeVariants[Timestamp](m, "timestamp") + registerDefaultPgTypeVariants[Timestamptz](m, "timestamptz") + registerDefaultPgTypeVariants[Range[Timestamp]](m, "tsrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](m, "tsmultirange") + registerDefaultPgTypeVariants[Range[Timestamptz]](m, "tstzrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](m, "tstzmultirange") + registerDefaultPgTypeVariants[UUID](m, "uuid") + + return m +} + +func (m *Map) RegisterType(t *Type) { + m.oidToType[t.OID] = t + m.nameToType[t.Name] = t + m.oidToFormatCode[t.OID] = t.Codec.PreferredFormat() + + // Invalidated by type registration + m.reflectTypeToType = nil + for k := range m.memoizedScanPlans { + delete(m.memoizedScanPlans, k) + } +} + +// RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be +// encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is +// unknown, this additional mapping will be used by TypeForValue to determine a suitable data type. +func (m *Map) RegisterDefaultPgType(value any, name string) { + m.reflectTypeToName[reflect.TypeOf(value)] = name + + // Invalidated by type registration + m.reflectTypeToType = nil + for k := range m.memoizedScanPlans { + delete(m.memoizedScanPlans, k) + } +} + +func (m *Map) TypeForOID(oid uint32) (*Type, bool) { + dt, ok := m.oidToType[oid] + return dt, ok +} + +func (m *Map) TypeForName(name string) (*Type, bool) { + dt, ok := m.nameToType[name] + return dt, ok +} + +func (m *Map) buildReflectTypeToType() { + m.reflectTypeToType = make(map[reflect.Type]*Type) + + for reflectType, name := range m.reflectTypeToName { + if dt, ok := m.nameToType[name]; ok { + m.reflectTypeToType[reflectType] = dt + } + } +} + +// TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode +// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. +func (m *Map) TypeForValue(v any) (*Type, bool) { + if m.reflectTypeToType == nil { + m.buildReflectTypeToType() + } + + dt, ok := m.reflectTypeToType[reflect.TypeOf(v)] + return dt, ok +} + +// FormatCodeForOID returns the preferred format code for type oid. If the type is not registered it returns the text +// format code. +func (m *Map) FormatCodeForOID(oid uint32) int16 { + fc, ok := m.oidToFormatCode[oid] + if ok { + return fc + } + return TextFormatCode +} + +// EncodePlan is a precompiled plan to encode a particular type into a particular OID and format. +type EncodePlan interface { + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return + // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data + // written. + Encode(value any, buf []byte) (newBuf []byte, err error) +} + +// ScanPlan is a precompiled plan to scan into a type of destination. +type ScanPlan interface { + // Scan scans src into target. src is only valid during the call to Scan. The ScanPlan must not retain a reference to + // src. + Scan(src []byte, target any) error +} + +type scanPlanCodecSQLScanner struct { + c Codec + m *Map + oid uint32 + formatCode int16 +} + +func (plan *scanPlanCodecSQLScanner) Scan(src []byte, dst any) error { + value, err := plan.c.DecodeDatabaseSQLValue(plan.m, plan.oid, plan.formatCode, src) + if err != nil { + return err + } + + scanner := dst.(sql.Scanner) + return scanner.Scan(value) +} + +type scanPlanSQLScanner struct { + formatCode int16 +} + +func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error { + scanner := dst.(sql.Scanner) + if src == nil { + // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the + // text format path would be converted to empty string. + return scanner.Scan(nil) + } else if plan.formatCode == BinaryFormatCode { + return scanner.Scan(src) + } else { + return scanner.Scan(string(src)) + } +} + +type scanPlanString struct{} + +func (scanPlanString) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p := (dst).(*string) + *p = string(src) + return nil +} + +type scanPlanAnyTextToBytes struct{} + +func (scanPlanAnyTextToBytes) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanFail struct { + m *Map + oid uint32 + formatCode int16 +} + +func (plan *scanPlanFail) Scan(src []byte, dst any) error { + var format string + switch plan.formatCode { + case TextFormatCode: + format = "text" + case BinaryFormatCode: + format = "binary" + default: + format = fmt.Sprintf("unknown %d", plan.formatCode) + } + + var dataTypeName string + if t, ok := plan.m.oidToType[plan.oid]; ok { + dataTypeName = t.Name + } else { + dataTypeName = "unknown type" + } + + return fmt.Errorf("cannot scan %s (OID %d) in %v format into %T", dataTypeName, plan.oid, format, dst) +} + +// TryWrapScanPlanFunc is a function that tries to create a wrapper plan for target. If successful it returns a plan +// that will convert the target passed to Scan and then call the next plan. nextTarget is target as it will be converted +// by plan. It must be used to find another suitable ScanPlan. When it is found SetNext must be called on plan for it +// to be usabled. ok indicates if a suitable wrapper was found. +type TryWrapScanPlanFunc func(target any) (plan WrappedScanPlanNextSetter, nextTarget any, ok bool) + +type pointerPointerScanPlan struct { + dstType reflect.Type + next ScanPlan +} + +func (plan *pointerPointerScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *pointerPointerScanPlan) Scan(src []byte, dst any) error { + el := reflect.ValueOf(dst).Elem() + if src == nil { + el.Set(reflect.Zero(el.Type())) + return nil + } + + el.Set(reflect.New(el.Type().Elem())) + return plan.next.Scan(src, el.Interface()) +} + +// TryPointerPointerScanPlan handles a pointer to a pointer by setting the target to nil for SQL NULL and allocating and +// scanning for non-NULL. +func TryPointerPointerScanPlan(target any) (plan WrappedScanPlanNextSetter, nextTarget any, ok bool) { + if dstValue := reflect.ValueOf(target); dstValue.Kind() == reflect.Ptr { + elemValue := dstValue.Elem() + if elemValue.Kind() == reflect.Ptr { + plan = &pointerPointerScanPlan{dstType: dstValue.Type()} + return plan, reflect.Zero(elemValue.Type()).Interface(), true + } + } + + return nil, nil, false +} + +// SkipUnderlyingTypePlanner prevents PlanScan and PlanDecode from trying to use the underlying type. +type SkipUnderlyingTypePlanner interface { + SkipUnderlyingTypePlan() +} + +var elemKindToPointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ + reflect.Int: reflect.TypeOf(new(int)), + reflect.Int8: reflect.TypeOf(new(int8)), + reflect.Int16: reflect.TypeOf(new(int16)), + reflect.Int32: reflect.TypeOf(new(int32)), + reflect.Int64: reflect.TypeOf(new(int64)), + reflect.Uint: reflect.TypeOf(new(uint)), + reflect.Uint8: reflect.TypeOf(new(uint8)), + reflect.Uint16: reflect.TypeOf(new(uint16)), + reflect.Uint32: reflect.TypeOf(new(uint32)), + reflect.Uint64: reflect.TypeOf(new(uint64)), + reflect.Float32: reflect.TypeOf(new(float32)), + reflect.Float64: reflect.TypeOf(new(float64)), + reflect.String: reflect.TypeOf(new(string)), +} + +type underlyingTypeScanPlan struct { + dstType reflect.Type + nextDstType reflect.Type + next ScanPlan +} + +func (plan *underlyingTypeScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *underlyingTypeScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) +} + +// TryFindUnderlyingTypeScanPlan tries to convert to a Go builtin type. e.g. If value was of type MyString and +// MyString was defined as a string then a wrapper plan would be returned that converts MyString to string. +func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nextDst any, ok bool) { + if _, ok := dst.(SkipUnderlyingTypePlanner); ok { + return nil, nil, false + } + + dstValue := reflect.ValueOf(dst) + + if dstValue.Kind() == reflect.Ptr { + elemValue := dstValue.Elem() + nextDstType := elemKindToPointerTypes[elemValue.Kind()] + if nextDstType == nil && elemValue.Kind() == reflect.Slice { + if elemValue.Type().Elem().Kind() == reflect.Uint8 { + var v *[]byte + nextDstType = reflect.TypeOf(v) + } + } + + if nextDstType != nil && dstValue.Type() != nextDstType { + return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true + } + + } + + return nil, nil, false +} + +type WrappedScanPlanNextSetter interface { + SetNext(ScanPlan) + ScanPlan +} + +// TryWrapBuiltinTypeScanPlan tries to wrap a builtin type with a wrapper that provides additional methods. e.g. If +// value was of type int32 then a wrapper plan would be returned that converts target to a value that implements +// Int64Scanner. +func TryWrapBuiltinTypeScanPlan(target any) (plan WrappedScanPlanNextSetter, nextDst any, ok bool) { + switch target := target.(type) { + case *int8: + return &wrapInt8ScanPlan{}, (*int8Wrapper)(target), true + case *int16: + return &wrapInt16ScanPlan{}, (*int16Wrapper)(target), true + case *int32: + return &wrapInt32ScanPlan{}, (*int32Wrapper)(target), true + case *int64: + return &wrapInt64ScanPlan{}, (*int64Wrapper)(target), true + case *int: + return &wrapIntScanPlan{}, (*intWrapper)(target), true + case *uint8: + return &wrapUint8ScanPlan{}, (*uint8Wrapper)(target), true + case *uint16: + return &wrapUint16ScanPlan{}, (*uint16Wrapper)(target), true + case *uint32: + return &wrapUint32ScanPlan{}, (*uint32Wrapper)(target), true + case *uint64: + return &wrapUint64ScanPlan{}, (*uint64Wrapper)(target), true + case *uint: + return &wrapUintScanPlan{}, (*uintWrapper)(target), true + case *float32: + return &wrapFloat32ScanPlan{}, (*float32Wrapper)(target), true + case *float64: + return &wrapFloat64ScanPlan{}, (*float64Wrapper)(target), true + case *string: + return &wrapStringScanPlan{}, (*stringWrapper)(target), true + case *time.Time: + return &wrapTimeScanPlan{}, (*timeWrapper)(target), true + case *time.Duration: + return &wrapDurationScanPlan{}, (*durationWrapper)(target), true + case *net.IPNet: + return &wrapNetIPNetScanPlan{}, (*netIPNetWrapper)(target), true + case *net.IP: + return &wrapNetIPScanPlan{}, (*netIPWrapper)(target), true + case *netip.Prefix: + return &wrapNetipPrefixScanPlan{}, (*netipPrefixWrapper)(target), true + case *netip.Addr: + return &wrapNetipAddrScanPlan{}, (*netipAddrWrapper)(target), true + case *map[string]*string: + return &wrapMapStringToPointerStringScanPlan{}, (*mapStringToPointerStringWrapper)(target), true + case *map[string]string: + return &wrapMapStringToStringScanPlan{}, (*mapStringToStringWrapper)(target), true + case *[16]byte: + return &wrapByte16ScanPlan{}, (*byte16Wrapper)(target), true + case *[]byte: + return &wrapByteSliceScanPlan{}, (*byteSliceWrapper)(target), true + } + + return nil, nil, false +} + +type wrapInt8ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt8ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt8ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*int8Wrapper)(dst.(*int8))) +} + +type wrapInt16ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt16ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt16ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*int16Wrapper)(dst.(*int16))) +} + +type wrapInt32ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt32ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt32ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*int32Wrapper)(dst.(*int32))) +} + +type wrapInt64ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt64ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt64ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*int64Wrapper)(dst.(*int64))) +} + +type wrapIntScanPlan struct { + next ScanPlan +} + +func (plan *wrapIntScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapIntScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*intWrapper)(dst.(*int))) +} + +type wrapUint8ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint8ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint8ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*uint8Wrapper)(dst.(*uint8))) +} + +type wrapUint16ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint16ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint16ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*uint16Wrapper)(dst.(*uint16))) +} + +type wrapUint32ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint32ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint32ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*uint32Wrapper)(dst.(*uint32))) +} + +type wrapUint64ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint64ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint64ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*uint64Wrapper)(dst.(*uint64))) +} + +type wrapUintScanPlan struct { + next ScanPlan +} + +func (plan *wrapUintScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUintScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*uintWrapper)(dst.(*uint))) +} + +type wrapFloat32ScanPlan struct { + next ScanPlan +} + +func (plan *wrapFloat32ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapFloat32ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*float32Wrapper)(dst.(*float32))) +} + +type wrapFloat64ScanPlan struct { + next ScanPlan +} + +func (plan *wrapFloat64ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapFloat64ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*float64Wrapper)(dst.(*float64))) +} + +type wrapStringScanPlan struct { + next ScanPlan +} + +func (plan *wrapStringScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapStringScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*stringWrapper)(dst.(*string))) +} + +type wrapTimeScanPlan struct { + next ScanPlan +} + +func (plan *wrapTimeScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapTimeScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*timeWrapper)(dst.(*time.Time))) +} + +type wrapDurationScanPlan struct { + next ScanPlan +} + +func (plan *wrapDurationScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapDurationScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*durationWrapper)(dst.(*time.Duration))) +} + +type wrapNetIPNetScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetIPNetScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetIPNetScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*netIPNetWrapper)(dst.(*net.IPNet))) +} + +type wrapNetIPScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetIPScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetIPScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*netIPWrapper)(dst.(*net.IP))) +} + +type wrapNetipPrefixScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetipPrefixScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetipPrefixScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*netipPrefixWrapper)(dst.(*netip.Prefix))) +} + +type wrapNetipAddrScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetipAddrScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetipAddrScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*netipAddrWrapper)(dst.(*netip.Addr))) +} + +type wrapMapStringToPointerStringScanPlan struct { + next ScanPlan +} + +func (plan *wrapMapStringToPointerStringScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapMapStringToPointerStringScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*mapStringToPointerStringWrapper)(dst.(*map[string]*string))) +} + +type wrapMapStringToStringScanPlan struct { + next ScanPlan +} + +func (plan *wrapMapStringToStringScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapMapStringToStringScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*mapStringToStringWrapper)(dst.(*map[string]string))) +} + +type wrapByte16ScanPlan struct { + next ScanPlan +} + +func (plan *wrapByte16ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapByte16ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*byte16Wrapper)(dst.(*[16]byte))) +} + +type wrapByteSliceScanPlan struct { + next ScanPlan +} + +func (plan *wrapByteSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapByteSliceScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*byteSliceWrapper)(dst.(*[]byte))) +} + +type pointerEmptyInterfaceScanPlan struct { + codec Codec + m *Map + oid uint32 + formatCode int16 +} + +func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst any) error { + value, err := plan.codec.DecodeValue(plan.m, plan.oid, plan.formatCode, src) + if err != nil { + return err + } + + ptrAny := dst.(*any) + *ptrAny = value + + return nil +} + +// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +func TryWrapStructScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + var targetElemValue reflect.Value + if targetValue.IsNil() { + targetElemValue = reflect.New(targetValue.Type().Elem()) + } else { + targetElemValue = targetValue.Elem() + } + targetElemType := targetElemValue.Type() + + if targetElemType.Kind() == reflect.Struct { + exportedFields := getExportedFieldValues(targetElemValue) + if len(exportedFields) == 0 { + return nil, nil, false + } + + w := ptrStructWrapper{ + s: target, + exportedFields: exportedFields, + } + return &wrapAnyPtrStructScanPlan{}, &w, true + } + + return nil, nil, false +} + +type wrapAnyPtrStructScanPlan struct { + next ScanPlan +} + +func (plan *wrapAnyPtrStructScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target any) error { + w := ptrStructWrapper{ + s: target, + exportedFields: getExportedFieldValues(reflect.ValueOf(target).Elem()), + } + + return plan.next.Scan(src, &w) +} + +// TryWrapPtrSliceScanPlan tries to wrap a pointer to a single dimension slice. +func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { + // Avoid using reflect path for common types. + switch target := target.(type) { + case *[]int16: + return &wrapPtrSliceScanPlan[int16]{}, (*FlatArray[int16])(target), true + case *[]int32: + return &wrapPtrSliceScanPlan[int32]{}, (*FlatArray[int32])(target), true + case *[]int64: + return &wrapPtrSliceScanPlan[int64]{}, (*FlatArray[int64])(target), true + case *[]float32: + return &wrapPtrSliceScanPlan[float32]{}, (*FlatArray[float32])(target), true + case *[]float64: + return &wrapPtrSliceScanPlan[float64]{}, (*FlatArray[float64])(target), true + case *[]string: + return &wrapPtrSliceScanPlan[string]{}, (*FlatArray[string])(target), true + case *[]time.Time: + return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true + } + + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemValue := targetValue.Elem() + + if targetElemValue.Kind() == reflect.Slice { + return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: targetElemValue}, true + } + return nil, nil, false +} + +type wrapPtrSliceScanPlan[T any] struct { + next ScanPlan +} + +func (plan *wrapPtrSliceScanPlan[T]) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrSliceScanPlan[T]) Scan(src []byte, target any) error { + return plan.next.Scan(src, (*FlatArray[T])(target.(*[]T))) +} + +type wrapPtrSliceReflectScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrSliceReflectScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrSliceReflectScanPlan) Scan(src []byte, target any) error { + return plan.next.Scan(src, &anySliceArrayReflect{slice: reflect.ValueOf(target).Elem()}) +} + +// TryWrapPtrMultiDimSliceScanPlan tries to wrap a pointer to a multi-dimension slice. +func TryWrapPtrMultiDimSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemValue := targetValue.Elem() + + if targetElemValue.Kind() == reflect.Slice { + elemElemKind := targetElemValue.Type().Elem().Kind() + if elemElemKind == reflect.Slice { + if !isRagged(targetElemValue) { + return &wrapPtrMultiDimSliceScanPlan{}, &anyMultiDimSliceArray{slice: targetValue.Elem()}, true + } + } + } + + return nil, nil, false +} + +type wrapPtrMultiDimSliceScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrMultiDimSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrMultiDimSliceScanPlan) Scan(src []byte, target any) error { + return plan.next.Scan(src, &anyMultiDimSliceArray{slice: reflect.ValueOf(target).Elem()}) +} + +// PlanScan prepares a plan to scan a value into target. +func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { + oidMemo := m.memoizedScanPlans[oid] + if oidMemo == nil { + oidMemo = make(map[reflect.Type][2]ScanPlan) + m.memoizedScanPlans[oid] = oidMemo + } + targetReflectType := reflect.TypeOf(target) + typeMemo := oidMemo[targetReflectType] + plan := typeMemo[formatCode] + if plan == nil { + plan = m.planScan(oid, formatCode, target) + typeMemo[formatCode] = plan + oidMemo[targetReflectType] = typeMemo + } + + return plan +} + +func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { + if _, ok := target.(*UndecodedBytes); ok { + return scanPlanAnyToUndecodedBytes{} + } + + switch formatCode { + case BinaryFormatCode: + switch target.(type) { + case *string: + switch oid { + case TextOID, VarcharOID: + return scanPlanString{} + } + } + case TextFormatCode: + switch target.(type) { + case *string: + return scanPlanString{} + case *[]byte: + if oid != ByteaOID { + return scanPlanAnyTextToBytes{} + } + case TextScanner: + return scanPlanTextAnyToTextScanner{} + } + } + + var dt *Type + + if dataType, ok := m.TypeForOID(oid); ok { + dt = dataType + } else if dataType, ok := m.TypeForValue(target); ok { + dt = dataType + oid = dt.OID // Preserve assumed OID in case we are recursively called below. + } + + if dt != nil { + if plan := dt.Codec.PlanScan(m, oid, formatCode, target); plan != nil { + return plan + } + } + + for _, f := range m.TryWrapScanPlanFuncs { + if wrapperPlan, nextDst, ok := f(target); ok { + if nextPlan := m.planScan(oid, formatCode, nextDst); nextPlan != nil { + if _, failed := nextPlan.(*scanPlanFail); !failed { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + } + + if dt != nil { + if _, ok := target.(*any); ok { + return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, m: m, oid: oid, formatCode: formatCode} + } + + if _, ok := target.(sql.Scanner); ok { + return &scanPlanCodecSQLScanner{c: dt.Codec, m: m, oid: oid, formatCode: formatCode} + } + } + + if _, ok := target.(sql.Scanner); ok { + return &scanPlanSQLScanner{formatCode: formatCode} + } + + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} +} + +func (m *Map) Scan(oid uint32, formatCode int16, src []byte, dst any) error { + if dst == nil { + return nil + } + + plan := m.PlanScan(oid, formatCode, dst) + return plan.Scan(src, dst) +} + +func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest any) error { + switch dest := dest.(type) { + case *string: + if formatCode == BinaryFormatCode { + return fmt.Errorf("unknown oid %d in binary format cannot be scanned into %T", oid, dest) + } + *dest = string(buf) + return nil + case *[]byte: + *dest = buf + return nil + default: + if nextDst, retry := GetAssignToDstType(dest); retry { + return scanUnknownType(oid, formatCode, buf, nextDst) + } + return fmt.Errorf("unknown oid %d cannot be scanned into %T", oid, dest) + } +} + +var ErrScanTargetTypeChanged = errors.New("scan target type changed") + +func codecScan(codec Codec, m *Map, oid uint32, format int16, src []byte, dst any) error { + scanPlan := codec.PlanScan(m, oid, format, dst) + if scanPlan == nil { + return fmt.Errorf("PlanScan did not find a plan") + } + return scanPlan.Scan(src, dst) +} + +func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + if format == TextFormatCode { + return string(src), nil + } else { + value, err := codec.DecodeValue(m, oid, format, src) + if err != nil { + return nil, err + } + buf, err := m.Encode(oid, TextFormatCode, value, nil) + if err != nil { + return nil, err + } + return string(buf), nil + } +} + +// PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be +// found then nil is returned. +func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { + if format == TextFormatCode { + switch value.(type) { + case string: + return encodePlanStringToAnyTextFormat{} + case TextValuer: + return encodePlanTextValuerToAnyTextFormat{} + } + } + + var dt *Type + + if oid == 0 { + if dataType, ok := m.TypeForValue(value); ok { + dt = dataType + oid = dt.OID // Preserve assumed OID in case we are recursively called below. + } + } else { + if dataType, ok := m.TypeForOID(oid); ok { + dt = dataType + } + } + + if dt != nil { + if plan := dt.Codec.PlanEncode(m, oid, format, value); plan != nil { + return plan + } + } + + for _, f := range m.TryWrapEncodePlanFuncs { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + return nil +} + +type encodePlanStringToAnyTextFormat struct{} + +func (encodePlanStringToAnyTextFormat) Encode(value any, buf []byte) (newBuf []byte, err error) { + s := value.(string) + return append(buf, s...), nil +} + +type encodePlanTextValuerToAnyTextFormat struct{} + +func (encodePlanTextValuerToAnyTextFormat) Encode(value any, buf []byte) (newBuf []byte, err error) { + t, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + if !t.Valid { + return nil, nil + } + + return append(buf, t.String...), nil +} + +// TryWrapEncodePlanFunc is a function that tries to create a wrapper plan for value. If successful it returns a plan +// that will convert the value passed to Encode and then call the next plan. nextValue is value as it will be converted +// by plan. It must be used to find another suitable EncodePlan. When it is found SetNext must be called on plan for it +// to be usabled. ok indicates if a suitable wrapper was found. +type TryWrapEncodePlanFunc func(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) + +type derefPointerEncodePlan struct { + next EncodePlan +} + +func (plan *derefPointerEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *derefPointerEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + ptr := reflect.ValueOf(value) + + if ptr.IsNil() { + return nil, nil + } + + return plan.next.Encode(ptr.Elem().Interface(), buf) +} + +// TryWrapDerefPointerEncodePlan tries to dereference a pointer. e.g. If value was of type *string then a wrapper plan +// would be returned that derefences the value. +func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if valueType := reflect.TypeOf(value); valueType.Kind() == reflect.Ptr { + return &derefPointerEncodePlan{}, reflect.New(valueType.Elem()).Elem().Interface(), true + } + + return nil, nil, false +} + +var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ + reflect.Int: reflect.TypeOf(int(0)), + reflect.Int8: reflect.TypeOf(int8(0)), + reflect.Int16: reflect.TypeOf(int16(0)), + reflect.Int32: reflect.TypeOf(int32(0)), + reflect.Int64: reflect.TypeOf(int64(0)), + reflect.Uint: reflect.TypeOf(uint(0)), + reflect.Uint8: reflect.TypeOf(uint8(0)), + reflect.Uint16: reflect.TypeOf(uint16(0)), + reflect.Uint32: reflect.TypeOf(uint32(0)), + reflect.Uint64: reflect.TypeOf(uint64(0)), + reflect.Float32: reflect.TypeOf(float32(0)), + reflect.Float64: reflect.TypeOf(float64(0)), + reflect.String: reflect.TypeOf(""), +} + +type underlyingTypeEncodePlan struct { + nextValueType reflect.Type + next EncodePlan +} + +func (plan *underlyingTypeEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *underlyingTypeEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(reflect.ValueOf(value).Convert(plan.nextValueType).Interface(), buf) +} + +// TryWrapFindUnderlyingTypeEncodePlan tries to convert to a Go builtin type. e.g. If value was of type MyString and +// MyString was defined as a string then a wrapper plan would be returned that converts MyString to string. +func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if _, ok := value.(SkipUnderlyingTypePlanner); ok { + return nil, nil, false + } + + refValue := reflect.ValueOf(value) + + nextValueType := kindToTypes[refValue.Kind()] + if nextValueType != nil && refValue.Type() != nextValueType { + return &underlyingTypeEncodePlan{nextValueType: nextValueType}, refValue.Convert(nextValueType).Interface(), true + } + + return nil, nil, false +} + +type WrappedEncodePlanNextSetter interface { + SetNext(EncodePlan) + EncodePlan +} + +// TryWrapBuiltinTypeEncodePlan tries to wrap a builtin type with a wrapper that provides additional methods. e.g. If +// value was of type int32 then a wrapper plan would be returned that converts value to a type that implements +// Int64Valuer. +func TryWrapBuiltinTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + switch value := value.(type) { + case int8: + return &wrapInt8EncodePlan{}, int8Wrapper(value), true + case int16: + return &wrapInt16EncodePlan{}, int16Wrapper(value), true + case int32: + return &wrapInt32EncodePlan{}, int32Wrapper(value), true + case int64: + return &wrapInt64EncodePlan{}, int64Wrapper(value), true + case int: + return &wrapIntEncodePlan{}, intWrapper(value), true + case uint8: + return &wrapUint8EncodePlan{}, uint8Wrapper(value), true + case uint16: + return &wrapUint16EncodePlan{}, uint16Wrapper(value), true + case uint32: + return &wrapUint32EncodePlan{}, uint32Wrapper(value), true + case uint64: + return &wrapUint64EncodePlan{}, uint64Wrapper(value), true + case uint: + return &wrapUintEncodePlan{}, uintWrapper(value), true + case float32: + return &wrapFloat32EncodePlan{}, float32Wrapper(value), true + case float64: + return &wrapFloat64EncodePlan{}, float64Wrapper(value), true + case string: + return &wrapStringEncodePlan{}, stringWrapper(value), true + case time.Time: + return &wrapTimeEncodePlan{}, timeWrapper(value), true + case time.Duration: + return &wrapDurationEncodePlan{}, durationWrapper(value), true + case net.IPNet: + return &wrapNetIPNetEncodePlan{}, netIPNetWrapper(value), true + case net.IP: + return &wrapNetIPEncodePlan{}, netIPWrapper(value), true + case netip.Prefix: + return &wrapNetipPrefixEncodePlan{}, netipPrefixWrapper(value), true + case netip.Addr: + return &wrapNetipAddrEncodePlan{}, netipAddrWrapper(value), true + case map[string]*string: + return &wrapMapStringToPointerStringEncodePlan{}, mapStringToPointerStringWrapper(value), true + case map[string]string: + return &wrapMapStringToStringEncodePlan{}, mapStringToStringWrapper(value), true + case [16]byte: + return &wrapByte16EncodePlan{}, byte16Wrapper(value), true + case []byte: + return &wrapByteSliceEncodePlan{}, byteSliceWrapper(value), true + case fmt.Stringer: + return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true + } + + return nil, nil, false +} + +type wrapInt8EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt8EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt8EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int8Wrapper(value.(int8)), buf) +} + +type wrapInt16EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt16EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt16EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int16Wrapper(value.(int16)), buf) +} + +type wrapInt32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt32EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int32Wrapper(value.(int32)), buf) +} + +type wrapInt64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt64EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int64Wrapper(value.(int64)), buf) +} + +type wrapIntEncodePlan struct { + next EncodePlan +} + +func (plan *wrapIntEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapIntEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(intWrapper(value.(int)), buf) +} + +type wrapUint8EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint8EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint8EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint8Wrapper(value.(uint8)), buf) +} + +type wrapUint16EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint16EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint16EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint16Wrapper(value.(uint16)), buf) +} + +type wrapUint32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint32EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint32Wrapper(value.(uint32)), buf) +} + +type wrapUint64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint64EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint64Wrapper(value.(uint64)), buf) +} + +type wrapUintEncodePlan struct { + next EncodePlan +} + +func (plan *wrapUintEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUintEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uintWrapper(value.(uint)), buf) +} + +type wrapFloat32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapFloat32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapFloat32EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(float32Wrapper(value.(float32)), buf) +} + +type wrapFloat64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapFloat64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapFloat64EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(float64Wrapper(value.(float64)), buf) +} + +type wrapStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapStringEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(stringWrapper(value.(string)), buf) +} + +type wrapTimeEncodePlan struct { + next EncodePlan +} + +func (plan *wrapTimeEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapTimeEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(timeWrapper(value.(time.Time)), buf) +} + +type wrapDurationEncodePlan struct { + next EncodePlan +} + +func (plan *wrapDurationEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapDurationEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(durationWrapper(value.(time.Duration)), buf) +} + +type wrapNetIPNetEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetIPNetEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetIPNetEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netIPNetWrapper(value.(net.IPNet)), buf) +} + +type wrapNetIPEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetIPEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetIPEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netIPWrapper(value.(net.IP)), buf) +} + +type wrapNetipPrefixEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetipPrefixEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetipPrefixEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netipPrefixWrapper(value.(netip.Prefix)), buf) +} + +type wrapNetipAddrEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetipAddrEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetipAddrEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netipAddrWrapper(value.(netip.Addr)), buf) +} + +type wrapMapStringToPointerStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMapStringToPointerStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMapStringToPointerStringEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(mapStringToPointerStringWrapper(value.(map[string]*string)), buf) +} + +type wrapMapStringToStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMapStringToStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMapStringToStringEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(mapStringToStringWrapper(value.(map[string]string)), buf) +} + +type wrapByte16EncodePlan struct { + next EncodePlan +} + +func (plan *wrapByte16EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapByte16EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(byte16Wrapper(value.([16]byte)), buf) +} + +type wrapByteSliceEncodePlan struct { + next EncodePlan +} + +func (plan *wrapByteSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapByteSliceEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(byteSliceWrapper(value.([]byte)), buf) +} + +type wrapFmtStringerEncodePlan struct { + next EncodePlan +} + +func (plan *wrapFmtStringerEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapFmtStringerEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf) +} + +// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +func TryWrapStructEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if reflect.TypeOf(value).Kind() == reflect.Struct { + exportedFields := getExportedFieldValues(reflect.ValueOf(value)) + if len(exportedFields) == 0 { + return nil, nil, false + } + + w := structWrapper{ + s: value, + exportedFields: exportedFields, + } + return &wrapAnyStructEncodePlan{}, w, true + } + + return nil, nil, false +} + +type wrapAnyStructEncodePlan struct { + next EncodePlan +} + +func (plan *wrapAnyStructEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapAnyStructEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := structWrapper{ + s: value, + exportedFields: getExportedFieldValues(reflect.ValueOf(value)), + } + + return plan.next.Encode(w, buf) +} + +func getExportedFieldValues(structValue reflect.Value) []reflect.Value { + structType := structValue.Type() + exportedFields := make([]reflect.Value, 0, structValue.NumField()) + for i := 0; i < structType.NumField(); i++ { + sf := structType.Field(i) + if sf.IsExported() { + exportedFields = append(exportedFields, structValue.Field(i)) + } + } + + return exportedFields +} + +func TryWrapSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + // Avoid using reflect path for common types. + switch value := value.(type) { + case []int16: + return &wrapSliceEncodePlan[int16]{}, (FlatArray[int16])(value), true + case []int32: + return &wrapSliceEncodePlan[int32]{}, (FlatArray[int32])(value), true + case []int64: + return &wrapSliceEncodePlan[int64]{}, (FlatArray[int64])(value), true + case []float32: + return &wrapSliceEncodePlan[float32]{}, (FlatArray[float32])(value), true + case []float64: + return &wrapSliceEncodePlan[float64]{}, (FlatArray[float64])(value), true + case []string: + return &wrapSliceEncodePlan[string]{}, (FlatArray[string])(value), true + case []time.Time: + return &wrapSliceEncodePlan[time.Time]{}, (FlatArray[time.Time])(value), true + } + + if reflect.TypeOf(value).Kind() == reflect.Slice { + w := anySliceArrayReflect{ + slice: reflect.ValueOf(value), + } + return &wrapSliceEncodeReflectPlan{}, w, true + } + + return nil, nil, false +} + +type wrapSliceEncodePlan[T any] struct { + next EncodePlan +} + +func (plan *wrapSliceEncodePlan[T]) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapSliceEncodePlan[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := anySliceArrayReflect{ + slice: reflect.ValueOf(value), + } + + return plan.next.Encode(w, buf) +} + +type wrapSliceEncodeReflectPlan struct { + next EncodePlan +} + +func (plan *wrapSliceEncodeReflectPlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapSliceEncodeReflectPlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := anySliceArrayReflect{ + slice: reflect.ValueOf(value), + } + + return plan.next.Encode(w, buf) +} + +func TryWrapMultiDimSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + sliceValue := reflect.ValueOf(value) + if sliceValue.Kind() == reflect.Slice { + valueElemType := sliceValue.Type().Elem() + + if valueElemType.Kind() == reflect.Slice { + if !isRagged(sliceValue) { + w := anyMultiDimSliceArray{ + slice: reflect.ValueOf(value), + } + return &wrapMultiDimSliceEncodePlan{}, &w, true + } + } + } + + return nil, nil, false +} + +type wrapMultiDimSliceEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMultiDimSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMultiDimSliceEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := anyMultiDimSliceArray{ + slice: reflect.ValueOf(value), + } + + return plan.next.Encode(&w, buf) +} + +func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) error { + var format string + switch formatCode { + case TextFormatCode: + format = "text" + case BinaryFormatCode: + format = "binary" + default: + format = fmt.Sprintf("unknown (%d)", formatCode) + } + + var dataTypeName string + if t, ok := m.oidToType[oid]; ok { + dataTypeName = t.Name + } else { + dataTypeName = "unknown type" + } + + return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %s", value, format, dataTypeName, oid, err) +} + +// Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return +// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data +// written. +func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) { + if value == nil { + return nil, nil + } + + plan := m.PlanEncode(oid, formatCode, value) + if plan == nil { + if dv, ok := value.(driver.Valuer); ok { + if dv == nil { + return nil, nil + } + v, err := dv.Value() + if err != nil { + return nil, err + } + return m.Encode(oid, formatCode, v, buf) + } + + return nil, newEncodeError(value, m, oid, formatCode, errors.New("cannot find encode plan")) + } + + newBuf, err = plan.Encode(value, buf) + if err != nil { + return nil, newEncodeError(value, m, oid, formatCode, err) + } + + return newBuf, nil +} + +// SQLScanner returns a database/sql.Scanner for v. This is necessary for types like Array[T] and Range[T] where the +// type needs assistance from Map to implement the sql.Scanner interface. It is not necessary for types like Box that +// implement sql.Scanner directly. +// +// This uses the type of v to look up the PostgreSQL OID that v presumably came from. This means v must be registered +// with m by calling RegisterDefaultPgType. +func (m *Map) SQLScanner(v any) sql.Scanner { + if s, ok := v.(sql.Scanner); ok { + return s + } + + return &sqlScannerWrapper{m: m, v: v} +} + +type sqlScannerWrapper struct { + m *Map + v any +} + +func (w *sqlScannerWrapper) Scan(src any) error { + t, ok := w.m.TypeForValue(w.v) + if !ok { + return fmt.Errorf("cannot convert to sql.Scanner: cannot find registered type for %T", w.v) + } + + var bufSrc []byte + switch src := src.(type) { + case string: + bufSrc = []byte(src) + case []byte: + bufSrc = src + default: + bufSrc = []byte(fmt.Sprint(bufSrc)) + } + + return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v) +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go new file mode 100644 index 00000000..11ae39e0 --- /dev/null +++ b/pgtype/pgtype_test.go @@ -0,0 +1,344 @@ +package pgtype_test + +import ( + "context" + "database/sql" + "errors" + "net" + "os" + "regexp" + "strconv" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + _ "github.com/jackc/pgx/v5/stdlib" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var defaultConnTestRunner pgxtest.ConnTestRunner + +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config + } +} + +// Test for renamed types +type _string string +type _bool bool +type _int8 int8 +type _int16 int16 +type _int16Slice []int16 +type _int32Slice []int32 +type _int64Slice []int64 +type _float32Slice []float32 +type _float64Slice []float64 +type _byteSlice []byte + +func mustParseCIDR(t testing.TB, s string) *net.IPNet { + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + + return ipnet +} + +func mustParseInet(t testing.TB, s string) *net.IPNet { + ip, ipnet, err := net.ParseCIDR(s) + if err == nil { + if ipv4 := ip.To4(); ipv4 != nil { + ipnet.IP = ipv4 + } else { + ipnet.IP = ip + } + return ipnet + } + + // May be bare IP address. + // + ip = net.ParseIP(s) + if ip == nil { + t.Fatal(errors.New("unable to parse inet address")) + } + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + if ipv4 := ip.To4(); ipv4 != nil { + ipnet.IP = ipv4 + ipnet.Mask = net.CIDRMask(32, 32) + } + return ipnet +} + +func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { + addr, err := net.ParseMAC(s) + if err != nil { + t.Fatal(err) + } + + return addr +} + +func skipCockroachDB(t testing.TB, msg string) { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + defer conn.Close(context.Background()) + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip(msg) + } +} + +func skipPostgreSQLVersionLessThan(t testing.TB, minVersion int64) { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + defer conn.Close(context.Background()) + + serverVersionStr := conn.PgConn().ParameterStatus("server_version") + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) + // if not PostgreSQL do nothing + if serverVersionStr == "" { + return + } + + serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) + require.NoError(t, err) + + if serverVersion < minVersion { + t.Skipf("Test requires PostgreSQL v%d+", minVersion) + } +} + +func TestMapScanNilIsNoOp(t *testing.T) { + m := pgtype.NewMap() + + err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil) + assert.NoError(t, err) +} + +func TestMapScanTextFormatInterfacePtr(t *testing.T) { + m := pgtype.NewMap() + var got any + err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got) + require.NoError(t, err) + assert.Equal(t, "foo", got) +} + +func TestMapScanTextFormatNonByteaIntoByteSlice(t *testing.T) { + m := pgtype.NewMap() + var got []byte + err := m.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &got) + require.NoError(t, err) + assert.Equal(t, []byte("{}"), got) +} + +func TestMapScanBinaryFormatInterfacePtr(t *testing.T) { + m := pgtype.NewMap() + var got any + err := m.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got) + require.NoError(t, err) + assert.Equal(t, "foo", got) +} + +func TestMapScanUnknownOIDToStringsAndBytes(t *testing.T) { + unknownOID := uint32(999999) + srcBuf := []byte("foo") + m := pgtype.NewMap() + + var s string + err := m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &s) + assert.NoError(t, err) + assert.Equal(t, "foo", s) + + var rs _string + err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rs) + assert.NoError(t, err) + assert.Equal(t, "foo", string(rs)) + + var b []byte + err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &b) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), b) + + var rb _byteSlice + err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rb) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), []byte(rb)) +} + +func TestMapScanPointerToNilStructDoesNotCrash(t *testing.T) { + m := pgtype.NewMap() + + type myStruct struct{} + var p *myStruct + err := m.Scan(0, pgx.TextFormatCode, []byte("(foo,bar)"), &p) + require.NotNil(t, err) +} + +func TestMapScanUnknownOIDTextFormat(t *testing.T) { + m := pgtype.NewMap() + + var n int32 + err := m.Scan(0, pgx.TextFormatCode, []byte("123"), &n) + assert.NoError(t, err) + assert.EqualValues(t, 123, n) +} + +func TestMapScanUnknownOIDIntoSQLScanner(t *testing.T) { + m := pgtype.NewMap() + + var s sql.NullString + err := m.Scan(0, pgx.TextFormatCode, []byte(nil), &s) + assert.NoError(t, err) + assert.Equal(t, "", s.String) + assert.False(t, s.Valid) +} + +type pgCustomInt int64 + +func (ci *pgCustomInt) Scan(src interface{}) error { + *ci = pgCustomInt(src.(int64)) + return nil +} + +func TestScanPlanBinaryInt32ScanScanner(t *testing.T) { + m := pgtype.NewMap() + src := []byte{0, 42} + var v pgCustomInt + + plan := m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v) + err := plan.Scan(src, &v) + require.NoError(t, err) + require.EqualValues(t, 42, v) + + ptr := new(pgCustomInt) + plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = new(pgCustomInt) + err = plan.Scan(nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) + + ptr = nil + plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = nil + plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) +} + +// Test for https://github.com/jackc/pgtype/issues/164 +func TestScanPlanInterface(t *testing.T) { + m := pgtype.NewMap() + src := []byte{0, 42} + var v interface{} + plan := m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, v) + err := plan.Scan(src, v) + assert.Error(t, err) +} + +// https://github.com/jackc/pgx/issues/1263 +func TestMapScanPtrToPtrToSlice(t *testing.T) { + m := pgtype.NewMap() + src := []byte("{foo,bar}") + var v *[]string + plan := m.PlanScan(pgtype.TextArrayOID, pgtype.TextFormatCode, &v) + err := plan.Scan(src, &v) + require.NoError(t, err) + require.Equal(t, []string{"foo", "bar"}, *v) +} + +func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) { + m := pgtype.NewMap() + src := []byte{0, 0, 0, 42} + var v pgtype.Int4 + + for i := 0; i < b.N; i++ { + v = pgtype.Int4{} + err := m.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != (pgtype.Int4{Int32: 42, Valid: true}) { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkMapScanInt4IntoGoInt32(b *testing.B) { + m := pgtype.NewMap() + src := []byte{0, 0, 0, 42} + var v int32 + + for i := 0; i < b.N; i++ { + v = 0 + err := m.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != 42 { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { + m := pgtype.NewMap() + src := []byte{0, 0, 0, 42} + var v pgtype.Int4 + + plan := m.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + + for i := 0; i < b.N; i++ { + v = pgtype.Int4{} + err := plan.Scan(src, &v) + if err != nil { + b.Fatal(err) + } + if v != (pgtype.Int4{Int32: 42, Valid: true}) { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { + m := pgtype.NewMap() + src := []byte{0, 0, 0, 42} + var v int32 + + plan := m.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + + for i := 0; i < b.N; i++ { + v = 0 + err := plan.Scan(src, &v) + if err != nil { + b.Fatal(err) + } + if v != 42 { + b.Fatal("scan failed due to bad value") + } + } +} + +func isExpectedEq(a any) func(any) bool { + return func(v any) bool { + return a == v + } +} diff --git a/pgtype/point.go b/pgtype/point.go new file mode 100644 index 00000000..cfa5a9f1 --- /dev/null +++ b/pgtype/point.go @@ -0,0 +1,266 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Vec2 struct { + X float64 + Y float64 +} + +type PointScanner interface { + ScanPoint(v Point) error +} + +type PointValuer interface { + PointValue() (Point, error) +} + +type Point struct { + P Vec2 + Valid bool +} + +func (p *Point) ScanPoint(v Point) error { + *p = v + return nil +} + +func (p Point) PointValue() (Point, error) { + return p, nil +} + +func parsePoint(src []byte) (*Point, error) { + if src == nil || bytes.Compare(src, []byte("null")) == 0 { + return &Point{}, nil + } + + if len(src) < 5 { + return nil, fmt.Errorf("invalid length for point: %v", len(src)) + } + if src[0] == '"' && src[len(src)-1] == '"' { + src = src[1 : len(src)-1] + } + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return nil, fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return nil, err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return nil, err + } + + return &Point{P: Vec2{x, y}, Valid: true}, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Point) Scan(src any) error { + if src == nil { + *dst = Point{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToPointScanner{}.Scan([]byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Point) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := PointCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +func (src Point) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)) + buff.WriteByte('"') + return buff.Bytes(), nil +} + +func (dst *Point) UnmarshalJSON(point []byte) error { + p, err := parsePoint(point) + if err != nil { + return err + } + *dst = *p + return nil +} + +type PointCodec struct{} + +func (PointCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (PointCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (PointCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(PointValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanPointCodecBinary{} + case TextFormatCode: + return encodePlanPointCodecText{} + } + + return nil +} + +type encodePlanPointCodecBinary struct{} + +func (encodePlanPointCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + point, err := value.(PointValuer).PointValue() + if err != nil { + return nil, err + } + + if !point.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(point.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(point.P.Y)) + return buf, nil +} + +type encodePlanPointCodecText struct{} + +func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + point, err := value.(PointValuer).PointValue() + if err != nil { + return nil, err + } + + if !point.Valid { + return nil, nil + } + + return append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(point.P.X, 'f', -1, 64), + strconv.FormatFloat(point.P.Y, 'f', -1, 64), + )...), nil +} + +func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case PointScanner: + return scanPlanBinaryPointToPointScanner{} + } + case TextFormatCode: + switch target.(type) { + case PointScanner: + return scanPlanTextAnyToPointScanner{} + } + } + + return nil +} + +func (c PointCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c PointCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var point Point + err := codecScan(c, m, oid, format, src, &point) + if err != nil { + return nil, err + } + return point, nil +} + +type scanPlanBinaryPointToPointScanner struct{} + +func (scanPlanBinaryPointToPointScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PointScanner) + + if src == nil { + return scanner.ScanPoint(Point{}) + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + + return scanner.ScanPoint(Point{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + Valid: true, + }) +} + +type scanPlanTextAnyToPointScanner struct{} + +func (scanPlanTextAnyToPointScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PointScanner) + + if src == nil { + return scanner.ScanPoint(Point{}) + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + return scanner.ScanPoint(Point{P: Vec2{x, y}, Valid: true}) +} diff --git a/pgtype/point_test.go b/pgtype/point_test.go new file mode 100644 index 00000000..336f1a47 --- /dev/null +++ b/pgtype/point_test.go @@ -0,0 +1,102 @@ +package pgtype_test + +import ( + "context" + "reflect" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +func TestPointCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support type point") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "point", []pgxtest.ValueRoundTripTest{ + { + pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}, + new(pgtype.Point), + isExpectedEq(pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}), + }, + { + pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Valid: true}, + new(pgtype.Point), + isExpectedEq(pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Valid: true}), + }, + {pgtype.Point{}, new(pgtype.Point), isExpectedEq(pgtype.Point{})}, + {nil, new(pgtype.Point), isExpectedEq(pgtype.Point{})}, + }) +} + +func TestPoint_MarshalJSON(t *testing.T) { + tests := []struct { + name string + point pgtype.Point + want []byte + }{ + { + name: "second", + point: pgtype.Point{ + P: pgtype.Vec2{X: 12.245, Y: 432.12}, + Valid: true, + }, + want: []byte(`"(12.245,432.12)"`), + }, + { + name: "third", + point: pgtype.Point{ + P: pgtype.Vec2{}, + }, + want: []byte("null"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.point.MarshalJSON() + require.NoError(t, err) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPoint_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + valid bool + arg []byte + wantErr bool + }{ + { + name: "first", + valid: true, + arg: []byte(`"(123.123,54.12)"`), + wantErr: false, + }, + { + name: "second", + valid: false, + arg: []byte(`"(123.123,54.1sad2)"`), + wantErr: true, + }, + { + name: "third", + valid: false, + arg: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Point{} + if err := dst.UnmarshalJSON(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Valid != tt.valid { + t.Errorf("Valid mismatch: %v != %v", dst.Valid, tt.valid) + } + }) + } +} diff --git a/pgtype/polygon.go b/pgtype/polygon.go new file mode 100644 index 00000000..04b0ba6b --- /dev/null +++ b/pgtype/polygon.go @@ -0,0 +1,253 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type PolygonScanner interface { + ScanPolygon(v Polygon) error +} + +type PolygonValuer interface { + PolygonValue() (Polygon, error) +} + +type Polygon struct { + P []Vec2 + Valid bool +} + +func (p *Polygon) ScanPolygon(v Polygon) error { + *p = v + return nil +} + +func (p Polygon) PolygonValue() (Polygon, error) { + return p, nil +} + +// Scan implements the database/sql Scanner interface. +func (p *Polygon) Scan(src any) error { + if src == nil { + *p = Polygon{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToPolygonScanner{}.Scan([]byte(src), p) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (p Polygon) Value() (driver.Value, error) { + if !p.Valid { + return nil, nil + } + + buf, err := PolygonCodec{}.PlanEncode(nil, 0, TextFormatCode, p).Encode(p, nil) + if err != nil { + return nil, err + } + + return string(buf), err +} + +type PolygonCodec struct{} + +func (PolygonCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (PolygonCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (PolygonCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(PolygonValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanPolygonCodecBinary{} + case TextFormatCode: + return encodePlanPolygonCodecText{} + } + + return nil +} + +type encodePlanPolygonCodecBinary struct{} + +func (encodePlanPolygonCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + polygon, err := value.(PolygonValuer).PolygonValue() + if err != nil { + return nil, err + } + + if !polygon.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, int32(len(polygon.P))) + + for _, p := range polygon.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +type encodePlanPolygonCodecText struct{} + +func (encodePlanPolygonCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + polygon, err := value.(PolygonValuer).PolygonValue() + if err != nil { + return nil, err + } + + if !polygon.Valid { + return nil, nil + } + + buf = append(buf, '(') + + for i, p := range polygon.P { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) + } + + buf = append(buf, ')') + + return buf, nil +} + +func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case PolygonScanner: + return scanPlanBinaryPolygonToPolygonScanner{} + } + case TextFormatCode: + switch target.(type) { + case PolygonScanner: + return scanPlanTextAnyToPolygonScanner{} + } + } + + return nil +} + +type scanPlanBinaryPolygonToPolygonScanner struct{} + +func (scanPlanBinaryPolygonToPolygonScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PolygonScanner) + + if src == nil { + return scanner.ScanPolygon(Polygon{}) + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for polygon: %v", len(src)) + } + + pointCount := int(binary.BigEndian.Uint32(src)) + rp := 4 + + if 4+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + return scanner.ScanPolygon(Polygon{ + P: points, + Valid: true, + }) +} + +type scanPlanTextAnyToPolygonScanner struct{} + +func (scanPlanTextAnyToPolygonScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PolygonScanner) + + if src == nil { + return scanner.ScanPolygon(Polygon{}) + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Polygon: %v", len(src)) + } + + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } + } + + return scanner.ScanPolygon(Polygon{P: points, Valid: true}) +} + +func (c PolygonCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c PolygonCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var polygon Polygon + err := codecScan(c, m, oid, format, src, &polygon) + if err != nil { + return nil, err + } + return polygon, nil +} diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go new file mode 100644 index 00000000..5ddbc166 --- /dev/null +++ b/pgtype/polygon_test.go @@ -0,0 +1,59 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqPolygon(a any) func(any) bool { + return func(v any) bool { + ap := a.(pgtype.Polygon) + vp := v.(pgtype.Polygon) + + if !(ap.Valid == vp.Valid && len(ap.P) == len(vp.P)) { + return false + } + + for i := range ap.P { + if ap.P[i] != vp.P[i] { + return false + } + } + + return true + } +} + +func TestPolygonTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support type polygon") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "polygon", []pgxtest.ValueRoundTripTest{ + { + pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, + Valid: true, + }, + new(pgtype.Polygon), + isExpectedEqPolygon(pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, + Valid: true, + }), + }, + { + pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Valid: true, + }, + new(pgtype.Polygon), + isExpectedEqPolygon(pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Valid: true, + }), + }, + {pgtype.Polygon{}, new(pgtype.Polygon), isExpectedEqPolygon(pgtype.Polygon{})}, + {nil, new(pgtype.Polygon), isExpectedEqPolygon(pgtype.Polygon{})}, + }) +} diff --git a/pgtype/qchar.go b/pgtype/qchar.go new file mode 100644 index 00000000..fc40a5b2 --- /dev/null +++ b/pgtype/qchar.go @@ -0,0 +1,141 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + "math" +) + +// QCharCodec is for PostgreSQL's special 8-bit-only "char" type more akin to the C +// language's char type, or Go's byte type. (Note that the name in PostgreSQL +// itself is "char", in double-quotes, and not char.) It gets used a lot in +// PostgreSQL's system tables to hold a single ASCII character value (eg +// pg_class.relkind). It is named Qchar for quoted char to disambiguate from SQL +// standard type char. +type QCharCodec struct{} + +func (QCharCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (QCharCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (QCharCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch value.(type) { + case byte: + return encodePlanQcharCodecByte{} + case rune: + return encodePlanQcharCodecRune{} + } + } + + return nil +} + +type encodePlanQcharCodecByte struct{} + +func (encodePlanQcharCodecByte) Encode(value any, buf []byte) (newBuf []byte, err error) { + b := value.(byte) + buf = append(buf, b) + return buf, nil +} + +type encodePlanQcharCodecRune struct{} + +func (encodePlanQcharCodecRune) Encode(value any, buf []byte) (newBuf []byte, err error) { + r := value.(rune) + if r > math.MaxUint8 { + return nil, fmt.Errorf(`%v cannot be encoded to "char"`, r) + } + b := byte(r) + buf = append(buf, b) + return buf, nil +} + +func (QCharCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch target.(type) { + case *byte: + return scanPlanQcharCodecByte{} + case *rune: + return scanPlanQcharCodecRune{} + } + } + + return nil +} + +type scanPlanQcharCodecByte struct{} + +func (scanPlanQcharCodecByte) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) > 1 { + return fmt.Errorf(`invalid length for "char": %v`, len(src)) + } + + b := dst.(*byte) + // In the text format the zero value is returned as a zero byte value instead of 0 + if len(src) == 0 { + *b = 0 + } else { + *b = src[0] + } + + return nil +} + +type scanPlanQcharCodecRune struct{} + +func (scanPlanQcharCodecRune) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) > 1 { + return fmt.Errorf(`invalid length for "char": %v`, len(src)) + } + + r := dst.(*rune) + // In the text format the zero value is returned as a zero byte value instead of 0 + if len(src) == 0 { + *r = 0 + } else { + *r = rune(src[0]) + } + + return nil +} + +func (c QCharCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var r rune + err := codecScan(c, m, oid, format, src, &r) + if err != nil { + return nil, err + } + return string(r), nil +} + +func (c QCharCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var r rune + err := codecScan(c, m, oid, format, src, &r) + if err != nil { + return nil, err + } + return r, nil +} diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go new file mode 100644 index 00000000..da00b89e --- /dev/null +++ b/pgtype/qchar_test.go @@ -0,0 +1,24 @@ +package pgtype_test + +import ( + "context" + "math" + "testing" + + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestQcharTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support qchar") + + var tests []pgxtest.ValueRoundTripTest + for i := 0; i <= math.MaxUint8; i++ { + tests = append(tests, pgxtest.ValueRoundTripTest{rune(i), new(rune), isExpectedEq(rune(i))}) + tests = append(tests, pgxtest.ValueRoundTripTest{byte(i), new(byte), isExpectedEq(byte(i))}) + } + tests = append(tests, pgxtest.ValueRoundTripTest{nil, new(*rune), isExpectedEq((*rune)(nil))}) + tests = append(tests, pgxtest.ValueRoundTripTest{nil, new(*byte), isExpectedEq((*byte)(nil))}) + + // Can only test with known OIDs as rune and byte would be considered numbers. + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, `"char"`, tests) +} diff --git a/pgtype/range.go b/pgtype/range.go new file mode 100644 index 00000000..8f408f9f --- /dev/null +++ b/pgtype/range.go @@ -0,0 +1,322 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +type BoundType byte + +const ( + Inclusive = BoundType('i') + Exclusive = BoundType('e') + Unbounded = BoundType('U') + Empty = BoundType('E') +) + +func (bt BoundType) String() string { + return string(bt) +} + +type untypedTextRange struct { + Lower string + Upper string + LowerType BoundType + UpperType BoundType +} + +func parseUntypedTextRange(src string) (*untypedTextRange, error) { + utr := &untypedTextRange{} + if src == "empty" { + utr.LowerType = Empty + utr.UpperType = Empty + return utr, nil + } + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower bound: %v", err) + } + switch r { + case '(': + utr.LowerType = Exclusive + case '[': + utr.LowerType = Inclusive + default: + return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + buf.UnreadRune() + + if r == ',' { + utr.LowerType = Unbounded + } else { + utr.Lower, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing range separator: %v", err) + } + if r != ',' { + return nil, fmt.Errorf("missing range separator: %v", r) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + + if r == ')' || r == ']' { + utr.UpperType = Unbounded + } else { + buf.UnreadRune() + utr.Upper, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing upper bound: %v", err) + } + switch r { + case ')': + utr.UpperType = Exclusive + case ']': + utr.UpperType = Inclusive + default: + return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + return utr, nil +} + +func rangeParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return rangeParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case ',', '[', ']', '(', ')': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + if r != '"' { + buf.UnreadRune() + return s.String(), nil + } + } + s.WriteRune(r) + } +} + +type untypedBinaryRange struct { + Lower []byte + Upper []byte + LowerType BoundType + UpperType BoundType +} + +// 0 = () = 00000 +// 1 = empty = 00001 +// 2 = [) = 00010 +// 4 = (] = 00100 +// 6 = [] = 00110 +// 8 = ) = 01000 +// 12 = ] = 01100 +// 16 = ( = 10000 +// 18 = [ = 10010 +// 24 = = 11000 + +const emptyMask = 1 +const lowerInclusiveMask = 2 +const upperInclusiveMask = 4 +const lowerUnboundedMask = 8 +const upperUnboundedMask = 16 + +func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) { + ubr := &untypedBinaryRange{} + + if len(src) == 0 { + return nil, fmt.Errorf("range too short: %v", len(src)) + } + + rangeType := src[0] + rp := 1 + + if rangeType&emptyMask > 0 { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + } + ubr.LowerType = Empty + ubr.UpperType = Empty + return ubr, nil + } + + if rangeType&lowerInclusiveMask > 0 { + ubr.LowerType = Inclusive + } else if rangeType&lowerUnboundedMask > 0 { + ubr.LowerType = Unbounded + } else { + ubr.LowerType = Exclusive + } + + if rangeType&upperInclusiveMask > 0 { + ubr.UpperType = Inclusive + } else if rangeType&upperUnboundedMask > 0 { + ubr.UpperType = Unbounded + } else { + ubr.UpperType = Exclusive + } + + if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) + } + return ubr, nil + } + + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + val := src[rp : rp+valueLen] + rp += valueLen + + if ubr.LowerType != Unbounded { + ubr.Lower = val + } else { + ubr.Upper = val + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + return ubr, nil + } + + if ubr.UpperType != Unbounded { + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + ubr.Upper = src[rp : rp+valueLen] + rp += valueLen + } + + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + + return ubr, nil + +} + +// Range is a generic range type. +type Range[T any] struct { + Lower T + Upper T + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Range[T]) IsNull() bool { + return !r.Valid +} + +func (r Range[T]) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Range[T]) Bounds() (lower, upper any) { + return &r.Lower, &r.Upper +} + +func (r *Range[T]) ScanNull() error { + *r = Range[T]{} + return nil +} + +func (r *Range[T]) ScanBounds() (lowerTarget, upperTarget any) { + return &r.Lower, &r.Upper +} + +func (r *Range[T]) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + var zero T + r.Lower = zero + } + if upper == Unbounded || upper == Empty { + var zero T + r.Upper = zero + } + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go new file mode 100644 index 00000000..8cfb3a63 --- /dev/null +++ b/pgtype/range_codec.go @@ -0,0 +1,379 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// RangeValuer is a type that can be converted into a PostgreSQL range. +type RangeValuer interface { + // IsNull returns true if the value is SQL NULL. + IsNull() bool + + // BoundTypes returns the lower and upper bound types. + BoundTypes() (lower, upper BoundType) + + // Bounds returns the lower and upper range values. + Bounds() (lower, upper any) +} + +// RangeScanner is a type can be scanned from a PostgreSQL range. +type RangeScanner interface { + // ScanNull sets the value to SQL NULL. + ScanNull() error + + // ScanBounds returns values usable as a scan target. The returned values may not be scanned if the range is empty or + // the bound type is unbounded. + ScanBounds() (lowerTarget, upperTarget any) + + // SetBoundTypes sets the lower and upper bound types. ScanBounds will be called and the returned values scanned + // (if appropriate) before SetBoundTypes is called. If the bound types are unbounded or empty this method must + // also set the bound values. + SetBoundTypes(lower, upper BoundType) error +} + +// RangeCodec is a codec for any range type. +type RangeCodec struct { + ElementType *Type +} + +func (c *RangeCodec) FormatSupported(format int16) bool { + return c.ElementType.Codec.FormatSupported(format) +} + +func (c *RangeCodec) PreferredFormat() int16 { + if c.FormatSupported(BinaryFormatCode) { + return BinaryFormatCode + } + return TextFormatCode +} + +func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(RangeValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m} + case TextFormatCode: + return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m} + } + + return nil +} + +type encodePlanRangeCodecRangeValuerToBinary struct { + rc *RangeCodec + m *Map +} + +func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + getter := value.(RangeValuer) + + if getter.IsNull() { + return nil, nil + } + + lowerType, upperType := getter.BoundTypes() + lower, upper := getter.Bounds() + + var rangeType byte + switch lowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", lowerType) + } + + switch upperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", upperType) + } + + buf = append(buf, rangeType) + + if lowerType != Unbounded { + if lower == nil { + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") + } + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, lower) + if lowerPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", lower) + } + + buf, err = lowerPlan.Encode(lower, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err) + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if upperType != Unbounded { + if upper == nil { + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") + } + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, upper) + if upperPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", upper) + } + + buf, err = upperPlan.Encode(upper, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err) + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +type encodePlanRangeCodecRangeValuerToText struct { + rc *RangeCodec + m *Map +} + +func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) { + getter := value.(RangeValuer) + + if getter.IsNull() { + return nil, nil + } + + lowerType, upperType := getter.BoundTypes() + lower, upper := getter.Bounds() + + switch lowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", lowerType) + } + + if lowerType != Unbounded { + if lower == nil { + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") + } + + lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, lower) + if lowerPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", lower) + } + + buf, err = lowerPlan.Encode(lower, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err) + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if upperType != Unbounded { + if upper == nil { + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") + } + + upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, upper) + if upperPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", upper) + } + + buf, err = upperPlan.Encode(upper, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err) + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") + } + } + + switch upperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", upperType) + } + + return buf, nil +} + +func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case RangeScanner: + return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m} + } + case TextFormatCode: + switch target.(type) { + case RangeScanner: + return &scanPlanTextRangeToRangeScanner{rc: c, m: m} + } + } + + return nil +} + +type scanPlanBinaryRangeToRangeScanner struct { + rc *RangeCodec + m *Map +} + +func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error { + rangeScanner := (target).(RangeScanner) + + if src == nil { + return rangeScanner.ScanNull() + } + + ubr, err := parseUntypedBinaryRange(src) + if err != nil { + return err + } + + if ubr.LowerType == Empty { + return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) + } + + lowerTarget, upperTarget := rangeScanner.ScanBounds() + + if ubr.LowerType == Inclusive || ubr.LowerType == Exclusive { + lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, lowerTarget) + if lowerPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", lowerTarget) + } + + err = lowerPlan.Scan(ubr.Lower, lowerTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err) + } + } + + if ubr.UpperType == Inclusive || ubr.UpperType == Exclusive { + upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, upperTarget) + if upperPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", upperTarget) + } + + err = upperPlan.Scan(ubr.Upper, upperTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err) + } + } + + return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) +} + +type scanPlanTextRangeToRangeScanner struct { + rc *RangeCodec + m *Map +} + +func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error { + rangeScanner := (target).(RangeScanner) + + if src == nil { + return rangeScanner.ScanNull() + } + + utr, err := parseUntypedTextRange(string(src)) + if err != nil { + return err + } + + if utr.LowerType == Empty { + return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) + } + + lowerTarget, upperTarget := rangeScanner.ScanBounds() + + if utr.LowerType == Inclusive || utr.LowerType == Exclusive { + lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, lowerTarget) + if lowerPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", lowerTarget) + } + + err = lowerPlan.Scan([]byte(utr.Lower), lowerTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err) + } + } + + if utr.UpperType == Inclusive || utr.UpperType == Exclusive { + upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, upperTarget) + if upperPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", upperTarget) + } + + err = upperPlan.Scan([]byte(utr.Upper), upperTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err) + } + } + + return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) +} + +func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var r Range[any] + err := c.PlanScan(m, oid, format, &r).Scan(src, &r) + return r, err +} diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go new file mode 100644 index 00000000..c0628747 --- /dev/null +++ b/pgtype/range_codec_test.go @@ -0,0 +1,163 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +func TestRangeCodecTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4range", []pgxtest.ValueRoundTripTest{ + { + pgtype.Range[pgtype.Int4]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + new(pgtype.Range[pgtype.Int4]), + isExpectedEq(pgtype.Range[pgtype.Int4]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), + }, + { + pgtype.Range[pgtype.Int4]{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }, + new(pgtype.Range[pgtype.Int4]), + isExpectedEq(pgtype.Range[pgtype.Int4]{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }), + }, + {pgtype.Range[pgtype.Int4]{}, new(pgtype.Range[pgtype.Int4]), isExpectedEq(pgtype.Range[pgtype.Int4]{})}, + {nil, new(pgtype.Range[pgtype.Int4]), isExpectedEq(pgtype.Range[pgtype.Int4]{})}, + }) +} + +func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "numrange", []pgxtest.ValueRoundTripTest{ + { + pgtype.Range[pgtype.Float8]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + new(pgtype.Range[pgtype.Float8]), + isExpectedEq(pgtype.Range[pgtype.Float8]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), + }, + { + pgtype.Range[pgtype.Float8]{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Float8{Float64: 1, Valid: true}, + Upper: pgtype.Float8{Float64: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }, + new(pgtype.Range[pgtype.Float8]), + isExpectedEq(pgtype.Range[pgtype.Float8]{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Float8{Float64: 1, Valid: true}, + Upper: pgtype.Float8{Float64: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }), + }, + {pgtype.Range[pgtype.Float8]{}, new(pgtype.Range[pgtype.Float8]), isExpectedEq(pgtype.Range[pgtype.Float8]{})}, + {nil, new(pgtype.Range[pgtype.Float8]), isExpectedEq(pgtype.Range[pgtype.Float8]{})}, + }) +} + +func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + + var r pgtype.Range[pgtype.Int4] + + err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + r, + ) + + err = conn.QueryRow(ctx, `select '[1,)'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Unbounded, + Valid: true, + }, + r, + ) + + err = conn.QueryRow(ctx, `select 'empty'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{}, + Upper: pgtype.Int4{}, + LowerType: pgtype.Empty, + UpperType: pgtype.Empty, + Valid: true, + }, + r, + ) + }) +} + +func TestRangeCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select '[1,5)'::int4range`, + expected: pgtype.Range[any]{ + Lower: int32(1), + Upper: int32(5), + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} diff --git a/pgtype/range_test.go b/pgtype/range_test.go new file mode 100644 index 00000000..1ee8d553 --- /dev/null +++ b/pgtype/range_test.go @@ -0,0 +1,177 @@ +package pgtype + +import ( + "bytes" + "testing" +) + +func TestParseUntypedTextRange(t *testing.T) { + tests := []struct { + src string + result untypedTextRange + err error + }{ + { + src: `[1,2)`, + result: untypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[1,2]`, + result: untypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: `(1,3)`, + result: untypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: ` [1,2) `, + result: untypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[ foo , bar )`, + result: untypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["foo","bar")`, + result: untypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: untypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: untypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["","bar")`, + result: untypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[f\"oo\,,b\\ar\))`, + result: untypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `empty`, + result: untypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := parseUntypedTextRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if r.Lower != tt.result.Lower { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if r.Upper != tt.result.Upper { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} + +func TestParseUntypedBinaryRange(t *testing.T) { + tests := []struct { + src []byte + result untypedBinaryRange + err error + }{ + { + src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{1}, + result: untypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, + err: nil, + }, + { + src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{8, 0, 0, 0, 2, 0, 5}, + result: untypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{12, 0, 0, 0, 2, 0, 5}, + result: untypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{16, 0, 0, 0, 2, 0, 4}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{18, 0, 0, 0, 2, 0, 4}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{24}, + result: untypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := parseUntypedBinaryRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if bytes.Compare(r.Lower, tt.result.Lower) != 0 { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if bytes.Compare(r.Upper, tt.result.Upper) != 0 { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} diff --git a/pgtype/record_codec.go b/pgtype/record_codec.go new file mode 100644 index 00000000..b3b16604 --- /dev/null +++ b/pgtype/record_codec.go @@ -0,0 +1,125 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +// ArrayGetter is a type that can be converted into a PostgreSQL array. + +// RecordCodec is a codec for the generic PostgreSQL record type such as is created with the "row" function. Record can +// only decode the binary format. The text format output format from PostgreSQL does not include type information and +// is therefore impossible to decode. Encoding is impossible because PostgreSQL does not support input of generic +// records. +type RecordCodec struct{} + +func (RecordCodec) FormatSupported(format int16) bool { + return format == BinaryFormatCode +} + +func (RecordCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (RecordCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + return nil +} + +func (RecordCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + if format == BinaryFormatCode { + switch target.(type) { + case CompositeIndexScanner: + return &scanPlanBinaryRecordToCompositeIndexScanner{m: m} + } + } + + return nil +} + +type scanPlanBinaryRecordToCompositeIndexScanner struct { + m *Map +} + +func (plan *scanPlanBinaryRecordToCompositeIndexScanner) Scan(src []byte, target any) error { + targetScanner := (target).(CompositeIndexScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeBinaryScanner(plan.m, src) + for i := 0; scanner.Next(); i++ { + fieldTarget := targetScanner.ScanIndex(i) + if fieldTarget != nil { + fieldPlan := plan.m.PlanScan(scanner.OID(), BinaryFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), fieldTarget) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + +func (RecordCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (RecordCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + scanner := NewCompositeBinaryScanner(m, src) + values := make([]any, scanner.FieldCount()) + for i := 0; scanner.Next(); i++ { + var v any + fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[i] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } + +} diff --git a/pgtype/record_codec_test.go b/pgtype/record_codec_test.go new file mode 100644 index 00000000..2189f99c --- /dev/null +++ b/pgtype/record_codec_test.go @@ -0,0 +1,73 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" +) + +func TestRecordCodec(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var a string + var b int32 + err := conn.QueryRow(ctx, `select row('foo'::text, 42::int4)`).Scan(pgtype.CompositeFields{&a, &b}) + require.NoError(t, err) + + require.Equal(t, "foo", a) + require.Equal(t, int32(42), b) + }) +} + +func TestRecordCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "Server converts row int4 to int8") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select row()`, + expected: []any{}, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: []any{"foo", int32(42)}, + }, + { + sql: `select row(100.0::float4, 1.09::float4)`, + expected: []any{float32(100), float32(1.09)}, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: []any{"foo", []any{int32(1), int32(2), nil, int32(4)}, int32(42)}, + }, + { + sql: `select row(null)`, + expected: []any{nil}, + }, + { + sql: `select null::record`, + expected: nil, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(context.Background(), tt.sql) + require.NoError(t, err) + defer rows.Close() + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} diff --git a/pgtype/register_default_pg_types.go b/pgtype/register_default_pg_types.go new file mode 100644 index 00000000..be1ca4a1 --- /dev/null +++ b/pgtype/register_default_pg_types.go @@ -0,0 +1,35 @@ +//go:build !nopgxregisterdefaulttypes + +package pgtype + +func registerDefaultPgTypeVariants[T any](m *Map, name string) { + arrayName := "_" + name + + var value T + m.RegisterDefaultPgType(value, name) // T + m.RegisterDefaultPgType(&value, name) // *T + + var sliceT []T + m.RegisterDefaultPgType(sliceT, arrayName) // []T + m.RegisterDefaultPgType(&sliceT, arrayName) // *[]T + + var slicePtrT []*T + m.RegisterDefaultPgType(slicePtrT, arrayName) // []*T + m.RegisterDefaultPgType(&slicePtrT, arrayName) // *[]*T + + var arrayOfT Array[T] + m.RegisterDefaultPgType(arrayOfT, arrayName) // Array[T] + m.RegisterDefaultPgType(&arrayOfT, arrayName) // *Array[T] + + var arrayOfPtrT Array[*T] + m.RegisterDefaultPgType(arrayOfPtrT, arrayName) // Array[*T] + m.RegisterDefaultPgType(&arrayOfPtrT, arrayName) // *Array[*T] + + var flatArrayOfT FlatArray[T] + m.RegisterDefaultPgType(flatArrayOfT, arrayName) // FlatArray[T] + m.RegisterDefaultPgType(&flatArrayOfT, arrayName) // *FlatArray[T] + + var flatArrayOfPtrT FlatArray[*T] + m.RegisterDefaultPgType(flatArrayOfPtrT, arrayName) // FlatArray[*T] + m.RegisterDefaultPgType(&flatArrayOfPtrT, arrayName) // *FlatArray[*T] +} diff --git a/pgtype/register_default_pg_types_disabled.go b/pgtype/register_default_pg_types_disabled.go new file mode 100644 index 00000000..56fe7c22 --- /dev/null +++ b/pgtype/register_default_pg_types_disabled.go @@ -0,0 +1,6 @@ +//go:build nopgxregisterdefaulttypes + +package pgtype + +func registerDefaultPgTypeVariants[T any](m *Map, name string) { +} diff --git a/pgtype/text.go b/pgtype/text.go new file mode 100644 index 00000000..021ee331 --- /dev/null +++ b/pgtype/text.go @@ -0,0 +1,223 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + "fmt" +) + +type TextScanner interface { + ScanText(v Text) error +} + +type TextValuer interface { + TextValue() (Text, error) +} + +type Text struct { + String string + Valid bool +} + +func (t *Text) ScanText(v Text) error { + *t = v + return nil +} + +func (t Text) TextValue() (Text, error) { + return t, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src any) error { + if src == nil { + *dst = Text{} + return nil + } + + switch src := src.(type) { + case string: + *dst = Text{String: src, Valid: true} + return nil + case []byte: + *dst = Text{String: string(src), Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Text) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return src.String, nil +} + +func (src Text) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + return json.Marshal(src.String) +} + +func (dst *Text) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Text{} + } else { + *dst = Text{String: *s, Valid: true} + } + + return nil +} + +type TextCodec struct{} + +func (TextCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TextCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (TextCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch value.(type) { + case string: + return encodePlanTextCodecString{} + case []byte: + return encodePlanTextCodecByteSlice{} + case TextValuer: + return encodePlanTextCodecTextValuer{} + } + } + + return nil +} + +type encodePlanTextCodecString struct{} + +func (encodePlanTextCodecString) Encode(value any, buf []byte) (newBuf []byte, err error) { + s := value.(string) + buf = append(buf, s...) + return buf, nil +} + +type encodePlanTextCodecByteSlice struct{} + +func (encodePlanTextCodecByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { + s := value.([]byte) + buf = append(buf, s...) + return buf, nil +} + +type encodePlanTextCodecStringer struct{} + +func (encodePlanTextCodecStringer) Encode(value any, buf []byte) (newBuf []byte, err error) { + s := value.(fmt.Stringer) + buf = append(buf, s.String()...) + return buf, nil +} + +type encodePlanTextCodecTextValuer struct{} + +func (encodePlanTextCodecTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + text, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + + if !text.Valid { + return nil, nil + } + + buf = append(buf, text.String...) + return buf, nil +} + +func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case TextFormatCode, BinaryFormatCode: + switch target.(type) { + case *string: + return scanPlanTextAnyToString{} + case *[]byte: + return scanPlanAnyToNewByteSlice{} + case BytesScanner: + return scanPlanAnyToByteScanner{} + case TextScanner: + return scanPlanTextAnyToTextScanner{} + } + } + + return nil +} + +func (c TextCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) +} + +func (c TextCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + return string(src), nil +} + +type scanPlanTextAnyToString struct{} + +func (scanPlanTextAnyToString) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p := (dst).(*string) + *p = string(src) + + return nil +} + +type scanPlanAnyToNewByteSlice struct{} + +func (scanPlanAnyToNewByteSlice) Scan(src []byte, dst any) error { + p := (dst).(*[]byte) + if src == nil { + *p = nil + } else { + *p = make([]byte, len(src)) + copy(*p, src) + } + + return nil +} + +type scanPlanAnyToByteScanner struct{} + +func (scanPlanAnyToByteScanner) Scan(src []byte, dst any) error { + p := (dst).(BytesScanner) + return p.ScanBytes(src) +} + +type scanPlanTextAnyToTextScanner struct{} + +func (scanPlanTextAnyToTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + return scanner.ScanText(Text{String: string(src), Valid: true}) +} diff --git a/pgtype/text_format_only_codec.go b/pgtype/text_format_only_codec.go new file mode 100644 index 00000000..d5e4cdb3 --- /dev/null +++ b/pgtype/text_format_only_codec.go @@ -0,0 +1,13 @@ +package pgtype + +type TextFormatOnlyCodec struct { + Codec +} + +func (c *TextFormatOnlyCodec) FormatSupported(format int16) bool { + return format == TextFormatCode && c.Codec.FormatSupported(format) +} + +func (TextFormatOnlyCodec) PreferredFormat() int16 { + return TextFormatCode +} diff --git a/pgtype/text_test.go b/pgtype/text_test.go new file mode 100644 index 00000000..eb5d005e --- /dev/null +++ b/pgtype/text_test.go @@ -0,0 +1,178 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +type someFmtStringer struct{} + +func (someFmtStringer) String() string { + return "some fmt.Stringer" +} + +func TestTextCodec(t *testing.T) { + for _, pgTypeName := range []string{"text", "varchar"} { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, pgTypeName, []pgxtest.ValueRoundTripTest{ + { + pgtype.Text{String: "", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "", Valid: true}), + }, + { + pgtype.Text{String: "foo", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "foo", Valid: true}), + }, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {"foo", new(string), isExpectedEq("foo")}, + {someFmtStringer{}, new(string), isExpectedEq("some fmt.Stringer")}, + }) + } +} + +// name is PostgreSQL's special 63-byte data type, used for identifiers like table names. The pg_class.relname column +// is a good example of where the name data type is used. +// +// TextCodec does not do length checking. Inputting a longer name into PostgreSQL will result in silent truncation to +// 63 bytes. +// +// Length checking would be possible with a Codec specialized for "name" but it would be perfect because a +// custom-compiled PostgreSQL could have set NAMEDATALEN to a different value rather than the default 63. +// +// So this is simply a smoke test of the name type. +func TestTextCodecName(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "name", []pgxtest.ValueRoundTripTest{ + { + pgtype.Text{String: "", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "", Valid: true}), + }, + { + pgtype.Text{String: "foo", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "foo", Valid: true}), + }, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {"foo", new(string), isExpectedEq("foo")}, + }) +} + +// Test fixed length char types like char(3) +func TestTextCodecBPChar(t *testing.T) { + skipCockroachDB(t, "Server does not properly handle bpchar with multi-byte character") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "char(3)", []pgxtest.ValueRoundTripTest{ + { + pgtype.Text{String: "a ", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "a ", Valid: true}), + }, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {" ", new(string), isExpectedEq(" ")}, + {"", new(string), isExpectedEq(" ")}, + {" 嗨 ", new(string), isExpectedEq(" 嗨 ")}, + }) +} + +// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// +// It only supports the text format. +func TestTextCodecACLItem(t *testing.T) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support type aclitem") + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "aclitem", []pgxtest.ValueRoundTripTest{ + { + pgtype.Text{String: "postgres=arwdDxt/postgres", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "postgres=arwdDxt/postgres", Valid: true}), + }, + {pgtype.Text{}, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + }) +} + +func TestTextCodecACLItemRoleWithSpecialCharacters(t *testing.T) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support type aclitem") + + // The tricky test user, below, has to actually exist so that it can be used in a test + // of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. + roleWithSpecialCharacters := ` tricky, ' } " \ test user ` + + commandTag, err := conn.Exec(ctx, `select * from pg_roles where rolname = $1`, roleWithSpecialCharacters) + require.NoError(t, err) + + if commandTag.RowsAffected() == 0 { + t.Skipf("Role with special characters does not exist.") + } + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "aclitem", []pgxtest.ValueRoundTripTest{ + { + pgtype.Text{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}), + }, + }) +} + +func TestTextMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Text + result string + }{ + {source: pgtype.Text{String: ""}, result: "null"}, + {source: pgtype.Text{String: "a", Valid: true}, result: "\"a\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestTextUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Text + }{ + {source: "null", result: pgtype.Text{String: ""}}, + {source: "\"a\"", result: pgtype.Text{String: "a", Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Text + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/tid.go b/pgtype/tid.go new file mode 100644 index 00000000..cb4a9ec4 --- /dev/null +++ b/pgtype/tid.go @@ -0,0 +1,241 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type TIDScanner interface { + ScanTID(v TID) error +} + +type TIDValuer interface { + TIDValue() (TID, error) +} + +// TID is PostgreSQL's Tuple Identifier type. +// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type TID struct { + BlockNumber uint32 + OffsetNumber uint16 + Valid bool +} + +func (b *TID) ScanTID(v TID) error { + *b = v + return nil +} + +func (b TID) TIDValue() (TID, error) { + return b, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TID) Scan(src any) error { + if src == nil { + *dst = TID{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToTIDScanner{}.Scan([]byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TID) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := TIDCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type TIDCodec struct{} + +func (TIDCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TIDCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TIDCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TIDValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTIDCodecBinary{} + case TextFormatCode: + return encodePlanTIDCodecText{} + } + + return nil +} + +type encodePlanTIDCodecBinary struct{} + +func (encodePlanTIDCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + tid, err := value.(TIDValuer).TIDValue() + if err != nil { + return nil, err + } + + if !tid.Valid { + return nil, nil + } + + buf = pgio.AppendUint32(buf, tid.BlockNumber) + buf = pgio.AppendUint16(buf, tid.OffsetNumber) + return buf, nil +} + +type encodePlanTIDCodecText struct{} + +func (encodePlanTIDCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + tid, err := value.(TIDValuer).TIDValue() + if err != nil { + return nil, err + } + + if !tid.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`(%d,%d)`, tid.BlockNumber, tid.OffsetNumber)...) + return buf, nil +} + +func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case TIDScanner: + return scanPlanBinaryTIDToTIDScanner{} + case TextScanner: + return scanPlanBinaryTIDToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case TIDScanner: + return scanPlanTextAnyToTIDScanner{} + } + } + + return nil +} + +type scanPlanBinaryTIDToTIDScanner struct{} + +func (scanPlanBinaryTIDToTIDScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TIDScanner) + + if src == nil { + return scanner.ScanTID(TID{}) + } + + if len(src) != 6 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + return scanner.ScanTID(TID{ + BlockNumber: binary.BigEndian.Uint32(src), + OffsetNumber: binary.BigEndian.Uint16(src[4:]), + Valid: true, + }) +} + +type scanPlanBinaryTIDToTextScanner struct{} + +func (scanPlanBinaryTIDToTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + if len(src) != 6 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + blockNumber := binary.BigEndian.Uint32(src) + offsetNumber := binary.BigEndian.Uint16(src[4:]) + + return scanner.ScanText(Text{ + String: fmt.Sprintf(`(%d,%d)`, blockNumber, offsetNumber), + Valid: true, + }) +} + +type scanPlanTextAnyToTIDScanner struct{} + +func (scanPlanTextAnyToTIDScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TIDScanner) + + if src == nil { + return scanner.ScanTID(TID{}) + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for tid") + } + + blockNumber, err := strconv.ParseUint(parts[0], 10, 32) + if err != nil { + return err + } + + offsetNumber, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return err + } + + return scanner.ScanTID(TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Valid: true}) +} + +func (c TIDCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c TIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var tid TID + err := codecScan(c, m, oid, format, src, &tid) + if err != nil { + return nil, err + } + return tid, nil +} diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go new file mode 100644 index 00000000..3e7a1a50 --- /dev/null +++ b/pgtype/tid_test.go @@ -0,0 +1,38 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestTIDCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support type tid") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "tid", []pgxtest.ValueRoundTripTest{ + { + pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, + new(pgtype.TID), + isExpectedEq(pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}), + }, + { + pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, + new(pgtype.TID), + isExpectedEq(pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}), + }, + { + pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, + new(string), + isExpectedEq("(42,43)"), + }, + { + pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, + new(string), + isExpectedEq("(4294967295,65535)"), + }, + {pgtype.TID{}, new(pgtype.TID), isExpectedEq(pgtype.TID{})}, + {nil, new(pgtype.TID), isExpectedEq(pgtype.TID{})}, + }) +} diff --git a/pgtype/time.go b/pgtype/time.go new file mode 100644 index 00000000..2eb6ace2 --- /dev/null +++ b/pgtype/time.go @@ -0,0 +1,233 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type TimeScanner interface { + ScanTime(v Time) error +} + +type TimeValuer interface { + TimeValue() (Time, error) +} + +// Time represents the PostgreSQL time type. The PostgreSQL time is a time of day without time zone. +// +// Time is represented as the number of microseconds since midnight in the same way that PostgreSQL does. Other time +// and date types in pgtype can use time.Time as the underlying representation. However, pgtype.Time type cannot due +// to needing to handle 24:00:00. time.Time converts that to 00:00:00 on the following day. +type Time struct { + Microseconds int64 // Number of microseconds since midnight + Valid bool +} + +func (t *Time) ScanTime(v Time) error { + *t = v + return nil +} + +func (t Time) TimeValue() (Time, error) { + return t, nil +} + +// Scan implements the database/sql Scanner interface. +func (t *Time) Scan(src any) error { + if src == nil { + *t = Time{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (t Time) Value() (driver.Value, error) { + if !t.Valid { + return nil, nil + } + + buf, err := TimeCodec{}.PlanEncode(nil, 0, TextFormatCode, t).Encode(t, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type TimeCodec struct{} + +func (TimeCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TimeCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TimeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TimeValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTimeCodecBinary{} + case TextFormatCode: + return encodePlanTimeCodecText{} + } + + return nil +} + +type encodePlanTimeCodecBinary struct{} + +func (encodePlanTimeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + t, err := value.(TimeValuer).TimeValue() + if err != nil { + return nil, err + } + + if !t.Valid { + return nil, nil + } + + return pgio.AppendInt64(buf, t.Microseconds), nil +} + +type encodePlanTimeCodecText struct{} + +func (encodePlanTimeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + t, err := value.(TimeValuer).TimeValue() + if err != nil { + return nil, err + } + + if !t.Valid { + return nil, nil + } + + usec := t.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + + s := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, usec) + + return append(buf, s...), nil +} + +func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case TimeScanner: + return scanPlanBinaryTimeToTimeScanner{} + } + case TextFormatCode: + switch target.(type) { + case TimeScanner: + return scanPlanTextAnyToTimeScanner{} + } + } + + return nil +} + +type scanPlanBinaryTimeToTimeScanner struct{} + +func (scanPlanBinaryTimeToTimeScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimeScanner) + + if src == nil { + return scanner.ScanTime(Time{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for time: %v", len(src)) + } + + usec := int64(binary.BigEndian.Uint64(src)) + + return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) +} + +type scanPlanTextAnyToTimeScanner struct{} + +func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimeScanner) + + if src == nil { + return scanner.ScanTime(Time{}) + } + + s := string(src) + + if len(s) < 8 { + return fmt.Errorf("cannot decode %v into Time", s) + } + + hours, err := strconv.ParseInt(s[0:2], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec := hours * microsecondsPerHour + + minutes, err := strconv.ParseInt(s[3:5], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec += minutes * microsecondsPerMinute + + seconds, err := strconv.ParseInt(s[6:8], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec += seconds * microsecondsPerSecond + + if len(s) > 9 { + fraction := s[9:] + n, err := strconv.ParseInt(fraction, 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + + for i := len(fraction); i < 6; i++ { + n *= 10 + } + + usec += n + } + + return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) +} + +func (c TimeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c TimeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var t Time + err := codecScan(c, m, oid, format, src, &t) + if err != nil { + return nil, err + } + return t, nil +} diff --git a/pgtype/time_test.go b/pgtype/time_test.go new file mode 100644 index 00000000..01bcee0f --- /dev/null +++ b/pgtype/time_test.go @@ -0,0 +1,47 @@ +package pgtype_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestTimeCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "time", []pgxtest.ValueRoundTripTest{ + { + pgtype.Time{Microseconds: 0, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 0, Valid: true}), + }, + { + pgtype.Time{Microseconds: 1, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 1, Valid: true}), + }, + { + pgtype.Time{Microseconds: 86399999999, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 86399999999, Valid: true}), + }, + { + pgtype.Time{Microseconds: 86400000000, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 86400000000, Valid: true}), + }, + { + time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 0, Valid: true}), + }, + { + pgtype.Time{Microseconds: 0, Valid: true}, + new(time.Time), + isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)), + }, + {pgtype.Time{}, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, + {nil, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, + }) +} diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go new file mode 100644 index 00000000..9f3de2c5 --- /dev/null +++ b/pgtype/timestamp.go @@ -0,0 +1,295 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strings" + "time" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +const pgTimestampFormat = "2006-01-02 15:04:05.999999999" + +type TimestampScanner interface { + ScanTimestamp(v Timestamp) error +} + +type TimestampValuer interface { + TimestampValue() (Timestamp, error) +} + +// Timestamp represents the PostgreSQL timestamp type. +type Timestamp struct { + Time time.Time // Time zone will be ignored when encoding to PostgreSQL. + InfinityModifier InfinityModifier + Valid bool +} + +func (ts *Timestamp) ScanTimestamp(v Timestamp) error { + *ts = v + return nil +} + +func (ts Timestamp) TimestampValue() (Timestamp, error) { + return ts, nil +} + +// Scan implements the database/sql Scanner interface. +func (ts *Timestamp) Scan(src any) error { + if src == nil { + *ts = Timestamp{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextTimestampToTimestampScanner{}.Scan([]byte(src), ts) + case time.Time: + *ts = Timestamp{Time: src, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (ts Timestamp) Value() (driver.Value, error) { + if !ts.Valid { + return nil, nil + } + + if ts.InfinityModifier != Finite { + return ts.InfinityModifier.String(), nil + } + return ts.Time, nil +} + +type TimestampCodec struct{} + +func (TimestampCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TimestampCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TimestampValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTimestampCodecBinary{} + case TextFormatCode: + return encodePlanTimestampCodecText{} + } + + return nil +} + +type encodePlanTimestampCodecBinary struct{} + +func (encodePlanTimestampCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestampValuer).TimestampValue() + if err != nil { + return nil, err + } + + if !ts.Valid { + return nil, nil + } + + var microsecSinceY2K int64 + switch ts.InfinityModifier { + case Finite: + t := discardTimeZone(ts.Time) + microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + buf = pgio.AppendInt64(buf, microsecSinceY2K) + + return buf, nil +} + +type encodePlanTimestampCodecText struct{} + +func (encodePlanTimestampCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestampValuer).TimestampValue() + if err != nil { + return nil, err + } + + if !ts.Valid { + return nil, nil + } + + var s string + + switch ts.InfinityModifier { + case Finite: + t := discardTimeZone(ts.Time) + + // Year 0000 is 1 BC + bc := false + if year := t.Year(); year <= 0 { + year = -year + 1 + t = time.Date(year, t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) + bc = true + } + + s = t.Truncate(time.Microsecond).Format(pgTimestampFormat) + + if bc { + s = s + " BC" + } + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + buf = append(buf, s...) + + return buf, nil +} + +func discardTimeZone(t time.Time) time.Time { + if t.Location() != time.UTC { + return time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) + } + + return t +} + +func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case TimestampScanner: + return scanPlanBinaryTimestampToTimestampScanner{} + } + case TextFormatCode: + switch target.(type) { + case TimestampScanner: + return scanPlanTextTimestampToTimestampScanner{} + } + } + + return nil +} + +type scanPlanBinaryTimestampToTimestampScanner struct{} + +func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimestampScanner) + + if src == nil { + return scanner.ScanTimestamp(Timestamp{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamp: %v", len(src)) + } + + var ts Timestamp + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + ts = Timestamp{Valid: true, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + ts = Timestamp{Valid: true, InfinityModifier: -Infinity} + default: + tim := time.Unix( + microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, + (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + ).UTC() + ts = Timestamp{Time: tim, Valid: true} + } + + return scanner.ScanTimestamp(ts) +} + +type scanPlanTextTimestampToTimestampScanner struct{} + +func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimestampScanner) + + if src == nil { + return scanner.ScanTimestamp(Timestamp{}) + } + + var ts Timestamp + sbuf := string(src) + switch sbuf { + case "infinity": + ts = Timestamp{Valid: true, InfinityModifier: Infinity} + case "-infinity": + ts = Timestamp{Valid: true, InfinityModifier: -Infinity} + default: + bc := false + if strings.HasSuffix(sbuf, " BC") { + sbuf = sbuf[:len(sbuf)-3] + bc = true + } + tim, err := time.Parse(pgTimestampFormat, sbuf) + if err != nil { + return err + } + + if bc { + year := -tim.Year() + 1 + tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) + } + + ts = Timestamp{Time: tim, Valid: true} + } + + return scanner.ScanTimestamp(ts) +} + +func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var ts Timestamp + err := codecScan(c, m, oid, format, src, &ts) + if err != nil { + return nil, err + } + + if ts.InfinityModifier != Finite { + return ts.InfinityModifier.String(), nil + } + + return ts.Time, nil +} + +func (c TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var ts Timestamp + err := codecScan(c, m, oid, format, src, &ts) + if err != nil { + return nil, err + } + + if ts.InfinityModifier != Finite { + return ts.InfinityModifier, nil + } + + return ts.Time, nil +} diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go new file mode 100644 index 00000000..849f55f6 --- /dev/null +++ b/pgtype/timestamp_test.go @@ -0,0 +1,64 @@ +package pgtype_test + +import ( + "context" + "testing" + "time" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +func TestTimestampCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + {time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC))}, + + {time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC))}, + {time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC))}, + + // Nanosecond truncation + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC))}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC))}, + + {pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Timestamp), isExpectedEq(pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Timestamp), isExpectedEq(pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Timestamp{}, new(pgtype.Timestamp), isExpectedEq(pgtype.Timestamp{})}, + {nil, new(*time.Time), isExpectedEq((*time.Time)(nil))}, + }) +} + +// https://github.com/jackc/pgx/v4/pgtype/pull/128 +func TestTimestampTranscodeBigTimeBinary(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + in := &pgtype.Timestamp{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} + var out pgtype.Timestamp + + err := conn.QueryRow(ctx, "select $1::timestamp", in).Scan(&out) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, in.Valid, out.Valid) + require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) + }) +} + +// https://github.com/jackc/pgtype/issues/74 +func TestTimestampCodecDecodeTextInvalid(t *testing.T) { + c := &pgtype.TimestampCodec{} + var ts pgtype.Timestamp + plan := c.PlanScan(nil, pgtype.TimestampOID, pgtype.TextFormatCode, &ts) + err := plan.Scan([]byte(`eeeee`), &ts) + require.Error(t, err) +} diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go new file mode 100644 index 00000000..f568fe30 --- /dev/null +++ b/pgtype/timestamptz.go @@ -0,0 +1,355 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" +const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" +const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +const ( + negativeInfinityMicrosecondOffset = -9223372036854775808 + infinityMicrosecondOffset = 9223372036854775807 +) + +type TimestamptzScanner interface { + ScanTimestamptz(v Timestamptz) error +} + +type TimestamptzValuer interface { + TimestamptzValue() (Timestamptz, error) +} + +// Timestamptz represents the PostgreSQL timestamptz type. +type Timestamptz struct { + Time time.Time + InfinityModifier InfinityModifier + Valid bool +} + +func (tstz *Timestamptz) ScanTimestamptz(v Timestamptz) error { + *tstz = v + return nil +} + +func (tstz Timestamptz) TimestamptzValue() (Timestamptz, error) { + return tstz, nil +} + +// Scan implements the database/sql Scanner interface. +func (tstz *Timestamptz) Scan(src any) error { + if src == nil { + *tstz = Timestamptz{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextTimestamptzToTimestamptzScanner{}.Scan([]byte(src), tstz) + case time.Time: + *tstz = Timestamptz{Time: src, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (tstz Timestamptz) Value() (driver.Value, error) { + if !tstz.Valid { + return nil, nil + } + + if tstz.InfinityModifier != Finite { + return tstz.InfinityModifier.String(), nil + } + return tstz.Time, nil +} + +func (tstz Timestamptz) MarshalJSON() ([]byte, error) { + if !tstz.Valid { + return []byte("null"), nil + } + + var s string + + switch tstz.InfinityModifier { + case Finite: + s = tstz.Time.Format(time.RFC3339Nano) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (tstz *Timestamptz) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *tstz = Timestamptz{} + return nil + } + + switch *s { + case "infinity": + *tstz = Timestamptz{Valid: true, InfinityModifier: Infinity} + case "-infinity": + *tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} + default: + // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz + tim, err := time.Parse(time.RFC3339Nano, *s) + if err != nil { + return err + } + + *tstz = Timestamptz{Time: tim, Valid: true} + } + + return nil +} + +type TimestamptzCodec struct{} + +func (TimestamptzCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TimestamptzCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TimestamptzValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTimestamptzCodecBinary{} + case TextFormatCode: + return encodePlanTimestamptzCodecText{} + } + + return nil +} + +type encodePlanTimestamptzCodecBinary struct{} + +func (encodePlanTimestamptzCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestamptzValuer).TimestamptzValue() + if err != nil { + return nil, err + } + + if !ts.Valid { + return nil, nil + } + + var microsecSinceY2K int64 + switch ts.InfinityModifier { + case Finite: + microsecSinceUnixEpoch := ts.Time.Unix()*1000000 + int64(ts.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + buf = pgio.AppendInt64(buf, microsecSinceY2K) + + return buf, nil +} + +type encodePlanTimestamptzCodecText struct{} + +func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestamptzValuer).TimestamptzValue() + if err != nil { + return nil, err + } + + if !ts.Valid { + return nil, nil + } + + var s string + + switch ts.InfinityModifier { + case Finite: + + t := ts.Time.UTC().Truncate(time.Microsecond) + + // Year 0000 is 1 BC + bc := false + if year := t.Year(); year <= 0 { + year = -year + 1 + t = time.Date(year, t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) + bc = true + } + + s = t.Format(pgTimestamptzSecondFormat) + + if bc { + s = s + " BC" + } + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + buf = append(buf, s...) + + return buf, nil +} + +func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case TimestamptzScanner: + return scanPlanBinaryTimestamptzToTimestamptzScanner{} + } + case TextFormatCode: + switch target.(type) { + case TimestamptzScanner: + return scanPlanTextTimestamptzToTimestamptzScanner{} + } + } + + return nil +} + +type scanPlanBinaryTimestamptzToTimestamptzScanner struct{} + +func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimestamptzScanner) + + if src == nil { + return scanner.ScanTimestamptz(Timestamptz{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", len(src)) + } + + var tstz Timestamptz + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + tstz = Timestamptz{Valid: true, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} + default: + tim := time.Unix( + microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, + (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + ) + tstz = Timestamptz{Time: tim, Valid: true} + } + + return scanner.ScanTimestamptz(tstz) +} + +type scanPlanTextTimestamptzToTimestamptzScanner struct{} + +func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimestamptzScanner) + + if src == nil { + return scanner.ScanTimestamptz(Timestamptz{}) + } + + var tstz Timestamptz + sbuf := string(src) + switch sbuf { + case "infinity": + tstz = Timestamptz{Valid: true, InfinityModifier: Infinity} + case "-infinity": + tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} + default: + bc := false + if strings.HasSuffix(sbuf, " BC") { + sbuf = sbuf[:len(sbuf)-3] + bc = true + } + + var format string + if len(sbuf) >= 9 && (sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+') { + format = pgTimestamptzSecondFormat + } else if len(sbuf) >= 6 && (sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+') { + format = pgTimestamptzMinuteFormat + } else { + format = pgTimestamptzHourFormat + } + + tim, err := time.Parse(format, sbuf) + if err != nil { + return err + } + + if bc { + year := -tim.Year() + 1 + tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) + } + + tstz = Timestamptz{Time: tim, Valid: true} + } + + return scanner.ScanTimestamptz(tstz) +} + +func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var tstz Timestamptz + err := codecScan(c, m, oid, format, src, &tstz) + if err != nil { + return nil, err + } + + if tstz.InfinityModifier != Finite { + return tstz.InfinityModifier.String(), nil + } + + return tstz.Time, nil +} + +func (c TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var tstz Timestamptz + err := codecScan(c, m, oid, format, src, &tstz) + if err != nil { + return nil, err + } + + if tstz.InfinityModifier != Finite { + return tstz.InfinityModifier, nil + } + + return tstz.Time, nil +} diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go new file mode 100644 index 00000000..0486ecdb --- /dev/null +++ b/pgtype/timestamptz_test.go @@ -0,0 +1,111 @@ +package pgtype_test + +import ( + "context" + "testing" + "time" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +func TestTimestamptzCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{ + {time.Date(-100, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(-100, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(-1, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(-1, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(0, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(0, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(1, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1, 1, 1, 0, 0, 0, 0, time.Local))}, + + {time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local))}, + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local))}, + {time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local))}, + + // Nanosecond truncation + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.Local), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local))}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.Local), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local))}, + + {pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Timestamptz), isExpectedEq(pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Timestamptz), isExpectedEq(pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Timestamptz{}, new(pgtype.Timestamptz), isExpectedEq(pgtype.Timestamptz{})}, + {nil, new(*time.Time), isExpectedEq((*time.Time)(nil))}, + }) +} + +// https://github.com/jackc/pgx/v4/pgtype/pull/128 +func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + in := &pgtype.Timestamptz{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} + var out pgtype.Timestamptz + + err := conn.QueryRow(ctx, "select $1::timestamptz", in).Scan(&out) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, in.Valid, out.Valid) + require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) + }) +} + +// https://github.com/jackc/pgtype/issues/74 +func TestTimestamptzDecodeTextInvalid(t *testing.T) { + c := &pgtype.TimestamptzCodec{} + var tstz pgtype.Timestamptz + plan := c.PlanScan(nil, pgtype.TimestamptzOID, pgtype.TextFormatCode, &tstz) + err := plan.Scan([]byte(`eeeee`), &tstz) + require.Error(t, err) +} + +func TestTimestamptzMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Timestamptz + result string + }{ + {source: pgtype.Timestamptz{}, result: "null"}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29T10:05:45-06:00\""}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29T10:05:45.555-06:00\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestTimestamptzUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Timestamptz + }{ + {source: "null", result: pgtype.Timestamptz{}}, + {source: "\"2012-03-29T10:05:45-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"2012-03-29T10:05:45.555-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: "\"-infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !r.Time.Equal(tt.result.Time) || r.Valid != tt.result.Valid || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/uint32.go b/pgtype/uint32.go new file mode 100644 index 00000000..098c516c --- /dev/null +++ b/pgtype/uint32.go @@ -0,0 +1,303 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Uint32Scanner interface { + ScanUint32(v Uint32) error +} + +type Uint32Valuer interface { + Uint32Value() (Uint32, error) +} + +// Uint32 is the core type that is used to represent PostgreSQL types such as OID, CID, and XID. +type Uint32 struct { + Uint32 uint32 + Valid bool +} + +func (n *Uint32) ScanUint32(v Uint32) error { + *n = v + return nil +} + +func (n Uint32) Uint32Value() (Uint32, error) { + return n, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Uint32) Scan(src any) error { + if src == nil { + *dst = Uint32{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + un, err := strconv.ParseUint(src, 10, 32) + if err != nil { + return err + } + n = int64(un) + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < 0 { + return fmt.Errorf("%d is less than the minimum value for Uint32", n) + } + if n > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for Uint32", n) + } + + *dst = Uint32{Uint32: uint32(n), Valid: true} + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Uint32) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Uint32), nil +} + +type Uint32Codec struct{} + +func (Uint32Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Uint32Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Uint32Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case uint32: + return encodePlanUint32CodecBinaryUint32{} + case Uint32Valuer: + return encodePlanUint32CodecBinaryUint32Valuer{} + case Int64Valuer: + return encodePlanUint32CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case uint32: + return encodePlanUint32CodecTextUint32{} + case Int64Valuer: + return encodePlanUint32CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanUint32CodecBinaryUint32 struct{} + +func (encodePlanUint32CodecBinaryUint32) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint32) + return pgio.AppendUint32(buf, v), nil +} + +type encodePlanUint32CodecBinaryUint32Valuer struct{} + +func (encodePlanUint32CodecBinaryUint32Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint32Valuer).Uint32Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return pgio.AppendUint32(buf, v.Uint32), nil +} + +type encodePlanUint32CodecBinaryInt64Valuer struct{} + +func (encodePlanUint32CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int64) + } + if v.Int64 > math.MaxUint32 { + return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int64) + } + + return pgio.AppendUint32(buf, uint32(v.Int64)), nil +} + +type encodePlanUint32CodecTextUint32 struct{} + +func (encodePlanUint32CodecTextUint32) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint32) + return append(buf, strconv.FormatUint(uint64(v), 10)...), nil +} + +type encodePlanUint32CodecTextUint32Valuer struct{} + +func (encodePlanUint32CodecTextUint32Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint32Valuer).Uint32Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return append(buf, strconv.FormatUint(uint64(v.Uint32), 10)...), nil +} + +type encodePlanUint32CodecTextInt64Valuer struct{} + +func (encodePlanUint32CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int64) + } + if v.Int64 > math.MaxUint32 { + return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int64) + } + + return append(buf, strconv.FormatInt(v.Int64, 10)...), nil +} + +func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *uint32: + return scanPlanBinaryUint32ToUint32{} + case Uint32Scanner: + return scanPlanBinaryUint32ToUint32Scanner{} + } + case TextFormatCode: + switch target.(type) { + case *uint32: + return scanPlanTextAnyToUint32{} + case Uint32Scanner: + return scanPlanTextAnyToUint32Scanner{} + } + } + + return nil +} + +func (c Uint32Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n uint32 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return int64(n), nil +} + +func (c Uint32Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n uint32 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryUint32ToUint32 struct{} + +func (scanPlanBinaryUint32ToUint32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint32: %v", len(src)) + } + + p := (dst).(*uint32) + *p = binary.BigEndian.Uint32(src) + + return nil +} + +type scanPlanBinaryUint32ToUint32Scanner struct{} + +func (scanPlanBinaryUint32ToUint32Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Uint32Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint32(Uint32{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint32: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + + return s.ScanUint32(Uint32{Uint32: n, Valid: true}) +} + +type scanPlanTextAnyToUint32Scanner struct{} + +func (scanPlanTextAnyToUint32Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Uint32Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint32(Uint32{}) + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + return s.ScanUint32(Uint32{Uint32: uint32(n), Valid: true}) +} diff --git a/pgtype/uint32_test.go b/pgtype/uint32_test.go new file mode 100644 index 00000000..842de643 --- /dev/null +++ b/pgtype/uint32_test.go @@ -0,0 +1,21 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestUint32Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "oid", []pgxtest.ValueRoundTripTest{ + { + pgtype.Uint32{Uint32: pgtype.TextOID, Valid: true}, + new(pgtype.Uint32), + isExpectedEq(pgtype.Uint32{Uint32: pgtype.TextOID, Valid: true}), + }, + {pgtype.Uint32{}, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, + {nil, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, + }) +} diff --git a/pgtype/uuid.go b/pgtype/uuid.go new file mode 100644 index 00000000..96a4c32f --- /dev/null +++ b/pgtype/uuid.go @@ -0,0 +1,269 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/hex" + "fmt" +) + +type UUIDScanner interface { + ScanUUID(v UUID) error +} + +type UUIDValuer interface { + UUIDValue() (UUID, error) +} + +type UUID struct { + Bytes [16]byte + Valid bool +} + +func (b *UUID) ScanUUID(v UUID) error { + *b = v + return nil +} + +func (b UUID) UUIDValue() (UUID, error) { + return b, nil +} + +// parseUUID converts a string UUID in standard form to a byte array. +func parseUUID(src string) (dst [16]byte, err error) { + switch len(src) { + case 36: + src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] + case 32: + // dashes already stripped, assume valid + default: + // assume invalid. + return dst, fmt.Errorf("cannot parse UUID %v", src) + } + + buf, err := hex.DecodeString(src) + if err != nil { + return dst, err + } + + copy(dst[:], buf) + return dst, err +} + +// encodeUUID converts a uuid byte array to UUID standard string form. +func encodeUUID(src [16]byte) string { + return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUID) Scan(src any) error { + if src == nil { + *dst = UUID{} + return nil + } + + switch src := src.(type) { + case string: + buf, err := parseUUID(src) + if err != nil { + return err + } + *dst = UUID{Bytes: buf, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src UUID) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + return encodeUUID(src.Bytes), nil +} + +func (src UUID) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(encodeUUID(src.Bytes)) + buff.WriteByte('"') + return buff.Bytes(), nil +} + +func (dst *UUID) UnmarshalJSON(src []byte) error { + if bytes.Compare(src, []byte("null")) == 0 { + *dst = UUID{} + return nil + } + if len(src) != 38 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + buf, err := parseUUID(string(src[1 : len(src)-1])) + if err != nil { + return err + } + *dst = UUID{Bytes: buf, Valid: true} + return nil +} + +type UUIDCodec struct{} + +func (UUIDCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (UUIDCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (UUIDCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(UUIDValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanUUIDCodecBinaryUUIDValuer{} + case TextFormatCode: + return encodePlanUUIDCodecTextUUIDValuer{} + } + + return nil +} + +type encodePlanUUIDCodecBinaryUUIDValuer struct{} + +func (encodePlanUUIDCodecBinaryUUIDValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + uuid, err := value.(UUIDValuer).UUIDValue() + if err != nil { + return nil, err + } + + if !uuid.Valid { + return nil, nil + } + + return append(buf, uuid.Bytes[:]...), nil +} + +type encodePlanUUIDCodecTextUUIDValuer struct{} + +func (encodePlanUUIDCodecTextUUIDValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + uuid, err := value.(UUIDValuer).UUIDValue() + if err != nil { + return nil, err + } + + if !uuid.Valid { + return nil, nil + } + + return append(buf, encodeUUID(uuid.Bytes)...), nil +} + +func (UUIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case UUIDScanner: + return scanPlanBinaryUUIDToUUIDScanner{} + case TextScanner: + return scanPlanBinaryUUIDToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case UUIDScanner: + return scanPlanTextAnyToUUIDScanner{} + } + } + + return nil +} + +type scanPlanBinaryUUIDToUUIDScanner struct{} + +func (scanPlanBinaryUUIDToUUIDScanner) Scan(src []byte, dst any) error { + scanner := (dst).(UUIDScanner) + + if src == nil { + return scanner.ScanUUID(UUID{}) + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + + uuid := UUID{Valid: true} + copy(uuid.Bytes[:], src) + + return scanner.ScanUUID(uuid) +} + +type scanPlanBinaryUUIDToTextScanner struct{} + +func (scanPlanBinaryUUIDToTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + + var buf [16]byte + copy(buf[:], src) + + return scanner.ScanText(Text{String: encodeUUID(buf), Valid: true}) +} + +type scanPlanTextAnyToUUIDScanner struct{} + +func (scanPlanTextAnyToUUIDScanner) Scan(src []byte, dst any) error { + scanner := (dst).(UUIDScanner) + + if src == nil { + return scanner.ScanUUID(UUID{}) + } + + buf, err := parseUUID(string(src)) + if err != nil { + return err + } + + return scanner.ScanUUID(UUID{Bytes: buf, Valid: true}) +} + +func (c UUIDCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var uuid UUID + err := codecScan(c, m, oid, format, src, &uuid) + if err != nil { + return nil, err + } + + return encodeUUID(uuid.Bytes), nil +} + +func (c UUIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var uuid UUID + err := codecScan(c, m, oid, format, src, &uuid) + if err != nil { + return nil, err + } + return uuid.Bytes, nil +} diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go new file mode 100644 index 00000000..2dc258b1 --- /dev/null +++ b/pgtype/uuid_test.go @@ -0,0 +1,132 @@ +package pgtype_test + +import ( + "context" + "reflect" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +func TestUUIDCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "uuid", []pgxtest.ValueRoundTripTest{ + { + pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + "00010203-0405-0607-0809-0a0b0c0d0e0f", + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + "000102030405060708090a0b0c0d0e0f", + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + new(string), + isExpectedEq("00010203-0405-0607-0809-0a0b0c0d0e0f"), + }, + {pgtype.UUID{}, new([]byte), isExpectedEqBytes([]byte(nil))}, + {pgtype.UUID{}, new(pgtype.UUID), isExpectedEq(pgtype.UUID{})}, + {nil, new(pgtype.UUID), isExpectedEq(pgtype.UUID{})}, + }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "uuid", []pgxtest.ValueRoundTripTest{ + { + [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + }) +} + +func TestUUID_MarshalJSON(t *testing.T) { + tests := []struct { + name string + src pgtype.UUID + want []byte + }{ + { + name: "first", + src: pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Valid: true, + }, + want: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + }, + { + name: "third", + src: pgtype.UUID{ + Bytes: [16]byte{}, + }, + want: []byte("null"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.src.MarshalJSON() + require.NoError(t, err) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUUID_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + want *pgtype.UUID + src []byte + wantErr bool + }{ + { + name: "first", + want: &pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Valid: true, + }, + src: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + wantErr: false, + }, + { + name: "second", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + }, + src: []byte("null"), + wantErr: false, + }, + { + name: "third", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + Valid: false, + }, + src: []byte("1d485a7a-6d18-4599-8c6c-34425616887a"), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &pgtype.UUID{} + if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pgtype/zeronull/doc.go b/pgtype/zeronull/doc.go new file mode 100644 index 00000000..78a52307 --- /dev/null +++ b/pgtype/zeronull/doc.go @@ -0,0 +1,22 @@ +// Package zeronull contains types that automatically convert between database NULLs and Go zero values. +/* +Sometimes the distinction between a zero value and a NULL value is not useful at the application level. For example, +in PostgreSQL an empty string may be stored as NULL. There is usually no application level distinction between an +empty string and a NULL string. Package zeronull implements types that seamlessly convert between PostgreSQL NULL and +the zero value. + +It is recommended to convert types at usage time rather than instantiate these types directly. In the example below, +middlename would be stored as a NULL. + + firstname := "John" + middlename := "" + lastname := "Smith" + _, err := conn.Exec( + ctx, + "insert into people(firstname, middlename, lastname) values($1, $2, $3)", + zeronull.Text(firstname), + zeronull.Text(middlename), + zeronull.Text(lastname), + ) +*/ +package zeronull diff --git a/pgtype/zeronull/float8.go b/pgtype/zeronull/float8.go new file mode 100644 index 00000000..08fa169e --- /dev/null +++ b/pgtype/zeronull/float8.go @@ -0,0 +1,56 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgx/v5/pgtype" +) + +type Float8 float64 + +func (Float8) SkipUnderlyingTypePlan() {} + +// ScanFloat64 implements the Float64Scanner interface. +func (f *Float8) ScanFloat64(n pgtype.Float8) error { + if !n.Valid { + *f = 0 + return nil + } + + *f = Float8(n.Float64) + + return nil +} + +func (f Float8) Float64Value() (pgtype.Float8, error) { + if f == 0 { + return pgtype.Float8{}, nil + } + return pgtype.Float8{Float64: float64(f), Valid: true}, nil +} + +// Scan implements the database/sql Scanner interface. +func (f *Float8) Scan(src any) error { + if src == nil { + *f = 0 + return nil + } + + var nullable pgtype.Float8 + err := nullable.Scan(src) + if err != nil { + return err + } + + *f = Float8(nullable.Float64) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (f Float8) Value() (driver.Value, error) { + if f == 0 { + return nil, nil + } + return float64(f), nil +} diff --git a/pgtype/zeronull/float8_test.go b/pgtype/zeronull/float8_test.go new file mode 100644 index 00000000..b3c818aa --- /dev/null +++ b/pgtype/zeronull/float8_test.go @@ -0,0 +1,35 @@ +package zeronull_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEq(a any) func(any) bool { + return func(v any) bool { + return a == v + } +} + +func TestFloat8Transcode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float8", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Float8)(1), + new(zeronull.Float8), + isExpectedEq((zeronull.Float8)(1)), + }, + { + nil, + new(zeronull.Float8), + isExpectedEq((zeronull.Float8)(0)), + }, + { + (zeronull.Float8)(0), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/int.go b/pgtype/zeronull/int.go new file mode 100644 index 00000000..4fec8a1a --- /dev/null +++ b/pgtype/zeronull/int.go @@ -0,0 +1,154 @@ +// Do not edit. Generated from pgtype/zeronull/int.go.erb +package zeronull + +import ( + "database/sql/driver" + "fmt" + "math" + + "github.com/jackc/pgx/v5/pgtype" +) + +type Int2 int16 + +func (Int2) SkipUnderlyingTypePlan() {} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int2) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = 0 + return nil + } + + if n < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + *dst = Int2(n) + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src any) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int2 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int2(nullable.Int16) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} + +type Int4 int32 + +func (Int4) SkipUnderlyingTypePlan() {} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int4) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = 0 + return nil + } + + if n < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + if n > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + *dst = Int4(n) + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src any) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int4 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int4(nullable.Int32) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} + +type Int8 int64 + +func (Int8) SkipUnderlyingTypePlan() {} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int8) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = 0 + return nil + } + + if n < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + if n > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + *dst = Int8(n) + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src any) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int8 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int8(nullable.Int64) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} diff --git a/pgtype/zeronull/int.go.erb b/pgtype/zeronull/int.go.erb new file mode 100644 index 00000000..b51cba12 --- /dev/null +++ b/pgtype/zeronull/int.go.erb @@ -0,0 +1,60 @@ +package zeronull + +import ( + "database/sql/driver" + "fmt" + "math" + + "github.com/jackc/pgx/v5/pgtype" +) + +<% [2, 4, 8].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +type Int<%= pg_byte_size %> int<%= pg_bit_size %> + +func (Int<%= pg_byte_size %>) SkipUnderlyingTypePlan() {} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int<%= pg_byte_size %>) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = 0 + return nil + } + + if n < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + if n > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + *dst = Int<%= pg_byte_size %>(n) + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int<%= pg_byte_size %>) Scan(src any) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int<%= pg_byte_size %> + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int<%= pg_byte_size %>(nullable.Int<%= pg_bit_size %>) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} +<% end %> diff --git a/pgtype/zeronull/int_test.go b/pgtype/zeronull/int_test.go new file mode 100644 index 00000000..7204cc88 --- /dev/null +++ b/pgtype/zeronull/int_test.go @@ -0,0 +1,70 @@ +// Do not edit. Generated from pgtype/zeronull/int_test.go.erb +package zeronull_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestInt2Transcode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int2", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Int2)(1), + new(zeronull.Int2), + isExpectedEq((zeronull.Int2)(1)), + }, + { + nil, + new(zeronull.Int2), + isExpectedEq((zeronull.Int2)(0)), + }, + { + (zeronull.Int2)(0), + new(any), + isExpectedEq(nil), + }, + }) +} + +func TestInt4Transcode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Int4)(1), + new(zeronull.Int4), + isExpectedEq((zeronull.Int4)(1)), + }, + { + nil, + new(zeronull.Int4), + isExpectedEq((zeronull.Int4)(0)), + }, + { + (zeronull.Int4)(0), + new(any), + isExpectedEq(nil), + }, + }) +} + +func TestInt8Transcode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Int8)(1), + new(zeronull.Int8), + isExpectedEq((zeronull.Int8)(1)), + }, + { + nil, + new(zeronull.Int8), + isExpectedEq((zeronull.Int8)(0)), + }, + { + (zeronull.Int8)(0), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/int_test.go.erb b/pgtype/zeronull/int_test.go.erb new file mode 100644 index 00000000..c0f72ef4 --- /dev/null +++ b/pgtype/zeronull/int_test.go.erb @@ -0,0 +1,31 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" +) + +<% [2, 4, 8].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +func TestInt<%= pg_byte_size %>Transcode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int<%= pg_byte_size %>", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Int<%= pg_byte_size %>)(1), + new(zeronull.Int<%= pg_byte_size %>), + isExpectedEq((zeronull.Int<%= pg_byte_size %>)(1)), + }, + { + nil, + new(zeronull.Int<%= pg_byte_size %>), + isExpectedEq((zeronull.Int<%= pg_byte_size %>)(0)), + }, + { + (zeronull.Int<%= pg_byte_size %>)(0), + new(any), + isExpectedEq(nil), + }, + }) +} +<% end %> diff --git a/pgtype/zeronull/text.go b/pgtype/zeronull/text.go new file mode 100644 index 00000000..4ba51fa9 --- /dev/null +++ b/pgtype/zeronull/text.go @@ -0,0 +1,49 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgx/v5/pgtype" +) + +type Text string + +func (Text) SkipUnderlyingTypePlan() {} + +// ScanText implements the TextScanner interface. +func (dst *Text) ScanText(v pgtype.Text) error { + if !v.Valid { + *dst = "" + return nil + } + + *dst = Text(v.String) + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src any) error { + if src == nil { + *dst = "" + return nil + } + + var nullable pgtype.Text + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Text(nullable.String) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Text) Value() (driver.Value, error) { + if src == "" { + return nil, nil + } + return string(src), nil +} diff --git a/pgtype/zeronull/text_test.go b/pgtype/zeronull/text_test.go new file mode 100644 index 00000000..5a60baf1 --- /dev/null +++ b/pgtype/zeronull/text_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestTextTranscode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "text", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Text)("foo"), + new(zeronull.Text), + isExpectedEq((zeronull.Text)("foo")), + }, + { + nil, + new(zeronull.Text), + isExpectedEq((zeronull.Text)("")), + }, + { + (zeronull.Text)(""), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/timestamp.go b/pgtype/zeronull/timestamp.go new file mode 100644 index 00000000..1697c420 --- /dev/null +++ b/pgtype/zeronull/timestamp.go @@ -0,0 +1,67 @@ +package zeronull + +import ( + "database/sql/driver" + "fmt" + "time" + + "github.com/jackc/pgx/v5/pgtype" +) + +type Timestamp time.Time + +func (Timestamp) SkipUnderlyingTypePlan() {} + +func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error { + if !v.Valid { + *ts = Timestamp{} + return nil + } + + switch v.InfinityModifier { + case pgtype.Finite: + *ts = Timestamp(v.Time) + return nil + case pgtype.Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case pgtype.NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) { + if time.Time(ts).IsZero() { + return pgtype.Timestamp{}, nil + } + + return pgtype.Timestamp{Time: time.Time(ts), Valid: true}, nil +} + +// Scan implements the database/sql Scanner interface. +func (ts *Timestamp) Scan(src any) error { + if src == nil { + *ts = Timestamp{} + return nil + } + + var nullable pgtype.Timestamp + err := nullable.Scan(src) + if err != nil { + return err + } + + *ts = Timestamp(nullable.Time) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (ts Timestamp) Value() (driver.Value, error) { + if time.Time(ts).IsZero() { + return nil, nil + } + + return time.Time(ts), nil +} diff --git a/pgtype/zeronull/timestamp_test.go b/pgtype/zeronull/timestamp_test.go new file mode 100644 index 00000000..8a5a5796 --- /dev/null +++ b/pgtype/zeronull/timestamp_test.go @@ -0,0 +1,39 @@ +package zeronull_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqTimestamp(a any) func(any) bool { + return func(v any) bool { + at := time.Time(a.(zeronull.Timestamp)) + vt := time.Time(v.(zeronull.Timestamp)) + + return at.Equal(vt) + } +} + +func TestTimestampTranscode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + new(zeronull.Timestamp), + isExpectedEqTimestamp((zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC))), + }, + { + nil, + new(zeronull.Timestamp), + isExpectedEqTimestamp((zeronull.Timestamp)(time.Time{})), + }, + { + (zeronull.Timestamp)(time.Time{}), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/timestamptz.go b/pgtype/zeronull/timestamptz.go new file mode 100644 index 00000000..55bc0c8e --- /dev/null +++ b/pgtype/zeronull/timestamptz.go @@ -0,0 +1,67 @@ +package zeronull + +import ( + "database/sql/driver" + "fmt" + "time" + + "github.com/jackc/pgx/v5/pgtype" +) + +type Timestamptz time.Time + +func (Timestamptz) SkipUnderlyingTypePlan() {} + +func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error { + if !v.Valid { + *ts = Timestamptz{} + return nil + } + + switch v.InfinityModifier { + case pgtype.Finite: + *ts = Timestamptz(v.Time) + return nil + case pgtype.Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case pgtype.NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) { + if time.Time(ts).IsZero() { + return pgtype.Timestamptz{}, nil + } + + return pgtype.Timestamptz{Time: time.Time(ts), Valid: true}, nil +} + +// Scan implements the database/sql Scanner interface. +func (ts *Timestamptz) Scan(src any) error { + if src == nil { + *ts = Timestamptz{} + return nil + } + + var nullable pgtype.Timestamp + err := nullable.Scan(src) + if err != nil { + return err + } + + *ts = Timestamptz(nullable.Time) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (ts Timestamptz) Value() (driver.Value, error) { + if time.Time(ts).IsZero() { + return nil, nil + } + + return time.Time(ts), nil +} diff --git a/pgtype/zeronull/timestamptz_test.go b/pgtype/zeronull/timestamptz_test.go new file mode 100644 index 00000000..0a6d380b --- /dev/null +++ b/pgtype/zeronull/timestamptz_test.go @@ -0,0 +1,39 @@ +package zeronull_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqTimestamptz(a any) func(any) bool { + return func(v any) bool { + at := time.Time(a.(zeronull.Timestamptz)) + vt := time.Time(v.(zeronull.Timestamptz)) + + return at.Equal(vt) + } +} + +func TestTimestamptzTranscode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + new(zeronull.Timestamptz), + isExpectedEqTimestamptz((zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC))), + }, + { + nil, + new(zeronull.Timestamptz), + isExpectedEqTimestamptz((zeronull.Timestamptz)(time.Time{})), + }, + { + (zeronull.Timestamptz)(time.Time{}), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/uuid.go b/pgtype/zeronull/uuid.go new file mode 100644 index 00000000..d88be84d --- /dev/null +++ b/pgtype/zeronull/uuid.go @@ -0,0 +1,62 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgx/v5/pgtype" +) + +type UUID [16]byte + +func (UUID) SkipUnderlyingTypePlan() {} + +// ScanUUID implements the UUIDScanner interface. +func (u *UUID) ScanUUID(v pgtype.UUID) error { + if !v.Valid { + *u = UUID{} + return nil + } + + *u = UUID(v.Bytes) + + return nil +} + +func (u UUID) UUIDValue() (pgtype.UUID, error) { + if u == (UUID{}) { + return pgtype.UUID{}, nil + } + return pgtype.UUID{Bytes: u, Valid: true}, nil +} + +// Scan implements the database/sql Scanner interface. +func (u *UUID) Scan(src any) error { + if src == nil { + *u = UUID{} + return nil + } + + var nullable pgtype.UUID + err := nullable.Scan(src) + if err != nil { + return err + } + + *u = UUID(nullable.Bytes) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (u UUID) Value() (driver.Value, error) { + if u == (UUID{}) { + return nil, nil + } + + buf, err := pgtype.UUIDCodec{}.PlanEncode(nil, pgtype.UUIDOID, pgtype.TextFormatCode, u).Encode(u, nil) + if err != nil { + return nil, err + } + + return string(buf), nil +} diff --git a/pgtype/zeronull/uuid_test.go b/pgtype/zeronull/uuid_test.go new file mode 100644 index 00000000..c50cb300 --- /dev/null +++ b/pgtype/zeronull/uuid_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestUUIDTranscode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "uuid", []pgxtest.ValueRoundTripTest{ + { + (zeronull.UUID)([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + new(zeronull.UUID), + isExpectedEq((zeronull.UUID)([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})), + }, + { + nil, + new(zeronull.UUID), + isExpectedEq((zeronull.UUID)([16]byte{})), + }, + { + (zeronull.UUID)([16]byte{}), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/zeronull.go b/pgtype/zeronull/zeronull.go new file mode 100644 index 00000000..bba7b423 --- /dev/null +++ b/pgtype/zeronull/zeronull.go @@ -0,0 +1,17 @@ +package zeronull + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +// Register registers the zeronull types so they can be used in query exec modes that do not know the server OIDs. +func Register(m *pgtype.Map) { + m.RegisterDefaultPgType(Float8(0), "float8") + m.RegisterDefaultPgType(Int2(0), "int2") + m.RegisterDefaultPgType(Int4(0), "int4") + m.RegisterDefaultPgType(Int8(0), "int8") + m.RegisterDefaultPgType(Text(""), "text") + m.RegisterDefaultPgType(Timestamp{}, "timestamp") + m.RegisterDefaultPgType(Timestamptz{}, "timestamptz") + m.RegisterDefaultPgType(UUID{}, "uuid") +} diff --git a/pgtype/zeronull/zeronull_test.go b/pgtype/zeronull/zeronull_test.go new file mode 100644 index 00000000..9ee45cb7 --- /dev/null +++ b/pgtype/zeronull/zeronull_test.go @@ -0,0 +1,26 @@ +package zeronull_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +var defaultConnTestRunner pgxtest.ConnTestRunner + +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config + } + defaultConnTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + zeronull.Register(conn.TypeMap()) + } +} diff --git a/pgxpool/batch_results.go b/pgxpool/batch_results.go index c625a474..5d5c681d 100644 --- a/pgxpool/batch_results.go +++ b/pgxpool/batch_results.go @@ -1,8 +1,8 @@ package pgxpool import ( - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) type errBatchResults struct { @@ -10,17 +10,13 @@ type errBatchResults struct { } func (br errBatchResults) Exec() (pgconn.CommandTag, error) { - return nil, br.err + return pgconn.CommandTag{}, br.err } func (br errBatchResults) Query() (pgx.Rows, error) { return errRows{err: br.err}, br.err } -func (br errBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - return nil, br.err -} - func (br errBatchResults) QueryRow() pgx.Row { return errRow{err: br.err} } @@ -42,10 +38,6 @@ func (br *poolBatchResults) Query() (pgx.Rows, error) { return br.br.Query() } -func (br *poolBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - return br.br.QueryFunc(scans, f) -} - func (br *poolBatchResults) QueryRow() pgx.Row { return br.br.QueryRow() } diff --git a/pgxpool/bench_test.go b/pgxpool/bench_test.go index 9ec63ca3..c2d58a38 100644 --- a/pgxpool/bench_test.go +++ b/pgxpool/bench_test.go @@ -5,13 +5,13 @@ import ( "os" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" ) func BenchmarkAcquireAndRelease(b *testing.B) { - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(b, err) defer pool.Close() @@ -34,7 +34,7 @@ func BenchmarkMinimalPreparedSelectBaseline(b *testing.B) { return err } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(b, err) conn, err := db.Acquire(context.Background()) @@ -65,7 +65,7 @@ func BenchmarkMinimalPreparedSelect(b *testing.B) { return err } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(b, err) var n int64 diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index c701e1f7..16f4f553 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -5,10 +5,10 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5/pgxpool" - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -21,17 +21,17 @@ func waitForReleaseToComplete() { } type execer interface { - Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) + Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) } func testExec(t *testing.T, db execer) { results, err := db.Exec(context.Background(), "set time zone 'America/Chicago'") require.NoError(t, err) - assert.EqualValues(t, "SET", results) + assert.EqualValues(t, "SET", results.String()) } type queryer interface { - Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) } func testQuery(t *testing.T, db queryer) { @@ -53,7 +53,7 @@ func testQuery(t *testing.T, db queryer) { } type queryRower interface { - QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row } func testQueryRow(t *testing.T, db queryRower) { @@ -103,7 +103,7 @@ func testCopyFrom(t *testing.T, db interface { tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) - inputRows := [][]interface{}{ + inputRows := [][]any{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } @@ -115,7 +115,7 @@ func testCopyFrom(t *testing.T, db interface { rows, err := db.Query(context.Background(), "select * from foo") assert.NoError(t, err) - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -148,7 +148,6 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName) assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName) assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName) - assert.Equalf(t, expected.LazyConnect, actual.LazyConnect, "%s - LazyConnect", testName) assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName) } @@ -161,15 +160,12 @@ func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, test return } - assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) - assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) + assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName) assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) - - // Can't test function equality, so just test that they are set or not. - assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) - - assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) - + assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName) + assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName) + assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) + assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) diff --git a/pgxpool/conn.go b/pgxpool/conn.go index 6482c821..36f90969 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -4,14 +4,14 @@ import ( "context" "sync/atomic" - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" - "github.com/jackc/puddle" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/puddle/v2" ) // Conn is an acquired *pgx.Conn from a Pool. type Conn struct { - res *puddle.Resource + res *puddle.Resource[*connResource] p *Pool } @@ -79,22 +79,18 @@ func (c *Conn) Hijack() *pgx.Conn { return conn } -func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { return c.Conn().Exec(ctx, sql, arguments...) } -func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { +func (c *Conn) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { return c.Conn().Query(ctx, sql, args...) } -func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { return c.Conn().QueryRow(ctx, sql, args...) } -func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - return c.Conn().QueryFunc(ctx, sql, args, scans, f) -} - func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { return c.Conn().SendBatch(ctx, b) } @@ -113,14 +109,6 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er return c.Conn().BeginTx(ctx, txOptions) } -func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { - return c.Conn().BeginFunc(ctx, f) -} - -func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error { - return c.Conn().BeginTxFunc(ctx, txOptions, f) -} - func (c *Conn) Ping(ctx context.Context) error { return c.Conn().Ping(ctx) } @@ -130,7 +118,7 @@ func (c *Conn) Conn() *pgx.Conn { } func (c *Conn) connResource() *connResource { - return c.res.Value().(*connResource) + return c.res.Value() } func (c *Conn) getPoolRow(r pgx.Row) *poolRow { diff --git a/pgxpool/conn_test.go b/pgxpool/conn_test.go index c03ae13e..175981b7 100644 --- a/pgxpool/conn_test.go +++ b/pgxpool/conn_test.go @@ -5,14 +5,14 @@ import ( "os" "testing" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" ) func TestConnExec(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -26,7 +26,7 @@ func TestConnExec(t *testing.T) { func TestConnQuery(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -40,7 +40,7 @@ func TestConnQuery(t *testing.T) { func TestConnQueryRow(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -54,7 +54,7 @@ func TestConnQueryRow(t *testing.T) { func TestConnSendBatch(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -68,7 +68,7 @@ func TestConnSendBatch(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() diff --git a/pgxpool/doc.go b/pgxpool/doc.go index e8239a6f..38e49795 100644 --- a/pgxpool/doc.go +++ b/pgxpool/doc.go @@ -2,11 +2,11 @@ /* pgxpool implements a nearly identical interface to pgx connections. -Establishing a Connection +Creating a Pool -The primary way of establishing a connection is with `pgxpool.Connect`. +The primary way of creating a pool is with `pgxpool.New`. - pool, err := pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL")) + pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) The database connection string can be in URL or DSN format. PostgreSQL settings, pgx settings, and pool settings can be specified here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the @@ -20,6 +20,9 @@ connection with `ConnectConfig`. // do something with every new connection } - pool, err := pgxpool.ConnectConfig(context.Background(), config) + pool, err := pgxpool.NewWithConfig(context.Background(), config) + +A pool returns without waiting for any connections to be established. Acquire a connection immediately after creating +the pool to check if a connection can successfully be established. */ package pgxpool diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 8e88ecaf..236ba000 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -10,9 +10,9 @@ import ( "sync/atomic" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" - "github.com/jackc/puddle" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/puddle/v2" ) var defaultMaxConns = int32(4) @@ -28,7 +28,7 @@ type connResource struct { poolRowss []poolRows } -func (cr *connResource) getConn(p *Pool, res *puddle.Resource) *Conn { +func (cr *connResource) getConn(p *Pool, res *puddle.Resource[*connResource]) *Conn { if len(cr.conns) == 0 { cr.conns = make([]Conn, 128) } @@ -70,16 +70,6 @@ func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows { return pr } -// detachedCtx wraps a context and will never be canceled, regardless of if -// the wrapped one is cancelled. The Err() method will never return any errors. -type detachedCtx struct { - context.Context -} - -func (detachedCtx) Done() <-chan struct{} { return nil } -func (detachedCtx) Deadline() (time.Time, bool) { return time.Time{}, false } -func (detachedCtx) Err() error { return nil } - // Pool allows for connection reuse. type Pool struct { // 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit @@ -88,7 +78,7 @@ type Pool struct { lifetimeDestroyCount int64 idleDestroyCount int64 - p *puddle.Pool + p *puddle.Pool[*connResource] config *Config beforeConnect func(context.Context, *pgx.ConnConfig) error afterConnect func(context.Context, *pgx.Conn) error @@ -100,7 +90,8 @@ type Pool struct { maxConnLifetimeJitter time.Duration maxConnIdleTime time.Duration healthCheckPeriod time.Duration - healthCheckChan chan struct{} + + healthCheckChan chan struct{} closeOnce sync.Once closeChan chan struct{} @@ -148,11 +139,6 @@ type Config struct { // HealthCheckPeriod is the duration between checks of the health of idle connections. HealthCheckPeriod time.Duration - // If set to true, pool doesn't do any I/O operation on initialization. - // And connects to the server only when the pool starts to be used. - // The default is false. - LazyConnect bool - createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -169,20 +155,18 @@ func (c *Config) Copy() *Config { // ConnString returns the connection string as parsed by pgxpool.ParseConfig into pgxpool.Config. func (c *Config) ConnString() string { return c.ConnConfig.ConnString() } -// Connect creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial -// connection. See ParseConfig for information on connString format. -func Connect(ctx context.Context, connString string) (*Pool, error) { +// New creates a new Pool. See ParseConfig for information on connString format. +func New(ctx context.Context, connString string) (*Pool, error) { config, err := ParseConfig(connString) if err != nil { return nil, err } - return ConnectConfig(ctx, config) + return NewWithConfig(ctx, config) } -// ConnectConfig creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial -// connection. config must have been created by ParseConfig. -func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { +// NewWithConfig creates a new Pool. config must have been created by ParseConfig. +func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { @@ -205,82 +189,66 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { closeChan: make(chan struct{}), } - p.p = puddle.NewPool( - func(ctx context.Context) (interface{}, error) { - // we ignore cancellation on the original context because its either from - // the health check or its from a query and we don't want to cancel creating - // a connection just because the original query was cancelled since that - // could end up stampeding the server - // this will keep any Values in the original context and will just ignore - // cancellation - // see https://github.com/jackc/pgx/issues/1259 - ctx = detachedCtx{ctx} + var err error + p.p, err = puddle.NewPool( + &puddle.Config[*connResource]{ + Constructor: func(ctx context.Context) (*connResource, error) { + connConfig := p.config.ConnConfig.Copy() - connConfig := p.config.ConnConfig.Copy() - - // But we do want to ensure that a connect won't hang forever. - if connConfig.ConnectTimeout <= 0 { - connConfig.ConnectTimeout = 2 * time.Minute - } - - if p.beforeConnect != nil { - if err := p.beforeConnect(ctx, connConfig); err != nil { - return nil, err + // Connection will continue in background even if Acquire is canceled. Ensure that a connect won't hang forever. + if connConfig.ConnectTimeout <= 0 { + connConfig.ConnectTimeout = 2 * time.Minute } - } - conn, err := pgx.ConnectConfig(ctx, connConfig) - if err != nil { - return nil, err - } + if p.beforeConnect != nil { + if err := p.beforeConnect(ctx, connConfig); err != nil { + return nil, err + } + } - if p.afterConnect != nil { - err = p.afterConnect(ctx, conn) + conn, err := pgx.ConnectConfig(ctx, connConfig) if err != nil { - conn.Close(ctx) return nil, err } - } - cr := &connResource{ - conn: conn, - conns: make([]Conn, 64), - poolRows: make([]poolRow, 64), - poolRowss: make([]poolRows, 64), - } + if p.afterConnect != nil { + err = p.afterConnect(ctx, conn) + if err != nil { + conn.Close(ctx) + return nil, err + } + } - return cr, nil + cr := &connResource{ + conn: conn, + conns: make([]Conn, 64), + poolRows: make([]poolRow, 64), + poolRowss: make([]poolRows, 64), + } + + return cr, nil + }, + Destructor: func(value *connResource) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + conn := value.conn + conn.Close(ctx) + select { + case <-conn.PgConn().CleanupDone(): + case <-ctx.Done(): + } + cancel() + }, + MaxSize: config.MaxConns, }, - func(value interface{}) { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - conn := value.(*connResource).conn - conn.Close(ctx) - select { - case <-conn.PgConn().CleanupDone(): - case <-ctx.Done(): - } - cancel() - }, - config.MaxConns, ) - - if !config.LazyConnect { - if err := p.checkMinConns(); err != nil { - // Couldn't create resources for minpool size. Close unhealthy pool. - p.Close() - return nil, err - } - - // Initially establish one connection - res, err := p.p.Acquire(ctx) - if err != nil { - p.Close() - return nil, err - } - res.Release() + if err != nil { + return nil, err } - go p.backgroundHealthCheck() + go func() { + p.createIdleResources(ctx, int(p.minConns)) + p.backgroundHealthCheck() + }() return p, nil } @@ -395,7 +363,7 @@ func (p *Pool) Close() { }) } -func (p *Pool) isExpired(res *puddle.Resource) bool { +func (p *Pool) isExpired(res *puddle.Resource[*connResource]) bool { now := time.Now() // Small optimization to avoid rand. If it's over lifetime AND jitter, immediately // return true. @@ -530,7 +498,16 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { return nil, err } - cr := res.Value().(*connResource) + cr := res.Value() + + if res.IdleDuration() > time.Second { + err := cr.conn.PgConn().CheckConn() + if err != nil { + res.Destroy() + continue + } + } + if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { return cr.getConn(p, res), nil } @@ -558,7 +535,7 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { resources := p.p.AcquireAllIdle() conns := make([]*Conn, 0, len(resources)) for _, res := range resources { - cr := res.Value().(*connResource) + cr := res.Value() if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { conns = append(conns, cr.getConn(p, res)) } else { @@ -569,6 +546,15 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { return conns } +// Reset closes all connections, but leaves the pool open. It is intended for use when an error is detected that would +// disrupt all connections (such as a network interruption or a server state change). +// +// It is safe to reset a pool while connections are checked out. Those connections will be closed when they are returned +// to the pool. +func (p *Pool) Reset() { + p.p.Reset() +} + // Config returns a copy of config that was used to initialize this pool. func (p *Pool) Config() *Config { return p.config.Copy() } @@ -586,10 +572,10 @@ func (p *Pool) Stat() *Stat { // SQL can be either a prepared statement name or an SQL string. // Arguments should be referenced positionally from the SQL string as $1, $2, etc. // The acquired connection is returned to the pool when the Exec function returns. -func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { +func (p *Pool) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { c, err := p.Acquire(ctx) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } defer c.Release() @@ -606,7 +592,7 @@ func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) ( // For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. -func (p *Pool) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { +func (p *Pool) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { c, err := p.Acquire(ctx) if err != nil { return errRows{err: err}, err @@ -633,7 +619,7 @@ func (p *Pool) Query(ctx context.Context, sql string, args ...interface{}) (pgx. // For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. -func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { +func (p *Pool) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { c, err := p.Acquire(ctx) if err != nil { return errRow{err: err} @@ -643,16 +629,6 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg return c.getPoolRow(row) } -func (p *Pool) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - c, err := p.Acquire(ctx) - if err != nil { - return nil, err - } - defer c.Release() - - return c.QueryFunc(ctx, sql, args, scans, f) -} - func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { c, err := p.Acquire(ctx) if err != nil { @@ -690,20 +666,6 @@ func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er return &Tx{t: t, c: c}, nil } -func (p *Pool) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { - return p.BeginTxFunc(ctx, pgx.TxOptions{}, f) -} - -func (p *Pool) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error { - c, err := p.Acquire(ctx) - if err != nil { - return err - } - defer c.Release() - - return c.BeginTxFunc(ctx, txOptions, f) -} - func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { c, err := p.Acquire(ctx) if err != nil { diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 4e712016..2ceb33cf 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -9,8 +9,9 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,7 +19,7 @@ import ( func TestConnect(t *testing.T) { t.Parallel() connString := os.Getenv("PGX_TEST_DATABASE") - pool, err := pgxpool.Connect(context.Background(), connString) + pool, err := pgxpool.New(context.Background(), connString) require.NoError(t, err) assert.Equal(t, connString, pool.Config().ConnString()) pool.Close() @@ -29,7 +30,7 @@ func TestConnectConfig(t *testing.T) { connString := os.Getenv("PGX_TEST_DATABASE") config, err := pgxpool.ParseConfig(connString) require.NoError(t, err) - pool, err := pgxpool.ConnectConfig(context.Background(), config) + pool, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) assertConfigsEqual(t, config, pool.Config(), "Pool.Config() returns original config") pool.Close() @@ -46,39 +47,11 @@ func TestParseConfigExtractsPoolArguments(t *testing.T) { assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_min_conns") } -func TestConnectCancel(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) - assert.Nil(t, pool) - assert.Equal(t, context.Canceled, err) -} - -func TestLazyConnect(t *testing.T) { - t.Parallel() - - config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - assert.NoError(t, err) - config.LazyConnect = true - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - pool, err := pgxpool.ConnectConfig(ctx, config) - assert.NoError(t, err) - - _, err = pool.Exec(ctx, "SELECT 1") - assert.Equal(t, context.Canceled, err) -} - func TestConstructorIgnoresContext(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) assert.NoError(t, err) - config.LazyConnect = true var cancel func() config.BeforeConnect = func(context.Context, *pgx.ConnConfig) error { // cancel the query's context before we actually Dial to ensure the Dial's @@ -87,7 +60,7 @@ func TestConstructorIgnoresContext(t *testing.T) { return nil } - pool, err := pgxpool.ConnectConfig(context.Background(), config) + pool, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) assert.EqualValues(t, 0, pool.Stat().TotalConns()) @@ -105,7 +78,7 @@ func TestConnectConfigRequiresConnConfigFromParseConfig(t *testing.T) { config := &pgxpool.Config{} - require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgxpool.ConnectConfig(context.Background(), config) }) + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgxpool.NewWithConfig(context.Background(), config) }) } func TestConfigCopyReturnsEqualConfig(t *testing.T) { @@ -125,7 +98,7 @@ func TestConfigCopyCanBeUsedToConnect(t *testing.T) { copied := original.Copy() assert.NotPanics(t, func() { - _, err = pgxpool.ConnectConfig(context.Background(), copied) + _, err = pgxpool.NewWithConfig(context.Background(), copied) }) assert.NoError(t, err) } @@ -133,7 +106,7 @@ func TestConfigCopyCanBeUsedToConnect(t *testing.T) { func TestPoolAcquireAndConnRelease(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -147,7 +120,7 @@ func TestPoolAcquireAndConnHijack(t *testing.T) { ctx := context.Background() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -168,10 +141,61 @@ func TestPoolAcquireAndConnHijack(t *testing.T) { require.Equal(t, int32(1), n) } +func TestPoolAcquireChecksIdleConns(t *testing.T) { + t.Parallel() + + controllerConn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer controllerConn.Close(context.Background()) + pgxtest.SkipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + var conns []*pgxpool.Conn + for i := 0; i < 3; i++ { + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + conns = append(conns, c) + } + + require.EqualValues(t, 3, pool.Stat().TotalConns()) + + var pids []uint32 + for _, c := range conns { + pids = append(pids, c.Conn().PgConn().PID()) + c.Release() + } + + _, err = controllerConn.Exec(context.Background(), `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids) + require.NoError(t, err) + + // All conns are dead they don't know it and neither does the pool. + require.EqualValues(t, 3, pool.Stat().TotalConns()) + + // Wait long enough so the pool will realize it needs to check the connections. + time.Sleep(time.Second) + + // Pool should try all existing connections and find them dead, then create a new connection which should successfully ping. + err = pool.Ping(context.Background()) + require.NoError(t, err) + + // The original 3 conns should have been terminated and the a new conn established for the ping. + require.EqualValues(t, 1, pool.Stat().TotalConns()) + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + + cPID := c.Conn().PgConn().PID() + c.Release() + + require.NotContains(t, pids, cPID) +} + func TestPoolAcquireFunc(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -186,7 +210,7 @@ func TestPoolAcquireFunc(t *testing.T) { func TestPoolAcquireFuncReturnsFnError(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -207,7 +231,7 @@ func TestPoolBeforeConnect(t *testing.T) { return nil } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -228,7 +252,7 @@ func TestPoolAfterConnect(t *testing.T) { return err } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -251,7 +275,7 @@ func TestPoolBeforeAcquire(t *testing.T) { return acquireAttempts%2 == 0 } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -283,7 +307,7 @@ func TestPoolAfterRelease(t *testing.T) { t.Parallel() func() { - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -306,7 +330,7 @@ func TestPoolAfterRelease(t *testing.T) { return afterReleaseCount%2 == 1 } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -326,19 +350,11 @@ func TestPoolAfterRelease(t *testing.T) { func TestPoolAcquireAllIdle(t *testing.T) { t.Parallel() - db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + db, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() - conns := db.AcquireAllIdle(context.Background()) - assert.Len(t, conns, 1) - - for _, c := range conns { - c.Release() - } - waitForReleaseToComplete() - - conns = make([]*pgxpool.Conn, 3) + conns := make([]*pgxpool.Conn, 3) for i := range conns { conns[i], err = db.Acquire(context.Background()) assert.NoError(t, err) @@ -359,6 +375,31 @@ func TestPoolAcquireAllIdle(t *testing.T) { } } +func TestPoolReset(t *testing.T) { + t.Parallel() + + db, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer db.Close() + + conns := make([]*pgxpool.Conn, 3) + for i := range conns { + conns[i], err = db.Acquire(context.Background()) + assert.NoError(t, err) + } + + db.Reset() + + for _, c := range conns { + if c != nil { + c.Release() + } + } + waitForReleaseToComplete() + + require.EqualValues(t, 0, db.Stat().TotalConns()) +} + func TestConnReleaseChecksMaxConnLifetime(t *testing.T) { t.Parallel() @@ -367,7 +408,7 @@ func TestConnReleaseChecksMaxConnLifetime(t *testing.T) { config.MaxConnLifetime = 250 * time.Millisecond - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -386,7 +427,7 @@ func TestConnReleaseChecksMaxConnLifetime(t *testing.T) { func TestConnReleaseClosesBusyConn(t *testing.T) { t.Parallel() - db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + db, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() @@ -420,7 +461,7 @@ func TestPoolBackgroundChecksMaxConnLifetime(t *testing.T) { config.MaxConnLifetime = 100 * time.Millisecond config.HealthCheckPeriod = 100 * time.Millisecond - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -445,7 +486,7 @@ func TestPoolBackgroundChecksMaxConnIdleTime(t *testing.T) { config.MaxConnIdleTime = 100 * time.Millisecond config.HealthCheckPeriod = 150 * time.Millisecond - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -474,7 +515,7 @@ func TestPoolBackgroundChecksMinConns(t *testing.T) { config.HealthCheckPeriod = 100 * time.Millisecond config.MinConns = 2 - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -502,7 +543,7 @@ func TestPoolBackgroundChecksMinConns(t *testing.T) { func TestPoolExec(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -512,7 +553,7 @@ func TestPoolExec(t *testing.T) { func TestPoolQuery(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -541,7 +582,7 @@ func TestPoolQuery(t *testing.T) { func TestPoolQueryRow(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -557,7 +598,7 @@ func TestPoolQueryRow(t *testing.T) { func TestPoolQueryRowErrNoRows(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -568,7 +609,7 @@ func TestPoolQueryRowErrNoRows(t *testing.T) { func TestPoolSendBatch(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -588,7 +629,7 @@ func TestPoolCopyFrom(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -601,7 +642,7 @@ func TestPoolCopyFrom(t *testing.T) { tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) - inputRows := [][]interface{}{ + inputRows := [][]any{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } @@ -613,7 +654,7 @@ func TestPoolCopyFrom(t *testing.T) { rows, err := pool.Query(ctx, "select * from poolcopyfromtest") assert.NoError(t, err) - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -631,7 +672,7 @@ func TestConnReleaseClosesConnInFailedTransaction(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -677,7 +718,7 @@ func TestConnReleaseClosesConnInTransaction(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -718,7 +759,7 @@ func TestConnReleaseDestroysClosedConn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -747,7 +788,7 @@ func TestConnReleaseDestroysClosedConn(t *testing.T) { func TestConnPoolQueryConcurrentLoad(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -773,7 +814,7 @@ func TestConnReleaseWhenBeginFail(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - db, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + db, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() @@ -797,7 +838,7 @@ func TestConnReleaseWhenBeginFail(t *testing.T) { } func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { - db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + db, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() @@ -816,15 +857,15 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { db.Exec(context.Background(), "drop table pgxpooltx") }() - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)") require.NoError(t, err) return nil @@ -844,7 +885,7 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { } func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { - db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + db, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() @@ -863,11 +904,11 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { db.Exec(context.Background(), "drop table pgxpooltx") }() - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") require.NoError(t, err) return errors.New("do a rollback") @@ -888,7 +929,7 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { } func TestIdempotentPoolClose(t *testing.T) { - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) // Close the open pool. @@ -898,7 +939,7 @@ func TestIdempotentPoolClose(t *testing.T) { require.NotPanics(t, func() { pool.Close() }) } -func TestConnectCreatesMinPool(t *testing.T) { +func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) @@ -906,7 +947,6 @@ func TestConnectCreatesMinPool(t *testing.T) { config.MinConns = int32(12) config.MaxConns = int32(15) - config.LazyConnect = false acquireAttempts := int64(0) connectAttempts := int64(0) @@ -920,166 +960,27 @@ func TestConnectCreatesMinPool(t *testing.T) { return nil } - pool, err := pgxpool.ConnectConfig(context.Background(), config) + pool, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer pool.Close() - stat := pool.Stat() - require.Equal(t, int32(12), stat.IdleConns()) - require.Equal(t, int64(1), stat.AcquireCount()) - require.Equal(t, int32(12), stat.TotalConns()) - require.Equal(t, int64(0), acquireAttempts) - require.Equal(t, int64(12), connectAttempts) -} -func TestConnectSkipMinPoolWithLazy(t *testing.T) { - t.Parallel() + for i := 0; i < 500; i++ { + time.Sleep(10 * time.Millisecond) - config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.MinConns = int32(12) - config.MaxConns = int32(15) - config.LazyConnect = true - - acquireAttempts := int64(0) - connectAttempts := int64(0) - - config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - atomic.AddInt64(&acquireAttempts, 1) - return true - } - config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { - atomic.AddInt64(&connectAttempts, 1) - return nil - } - - pool, err := pgxpool.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer pool.Close() - - stat := pool.Stat() - require.Equal(t, int32(0), stat.IdleConns()) - require.Equal(t, int64(0), stat.AcquireCount()) - require.Equal(t, int32(0), stat.TotalConns()) - require.Equal(t, int64(0), acquireAttempts) - require.Equal(t, int64(0), connectAttempts) -} - -func TestConnectMinPoolZero(t *testing.T) { - t.Parallel() - - config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.MinConns = int32(0) - config.MaxConns = int32(15) - config.LazyConnect = false - - acquireAttempts := int64(0) - connectAttempts := int64(0) - - config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - atomic.AddInt64(&acquireAttempts, 1) - return true - } - config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { - atomic.AddInt64(&connectAttempts, 1) - return nil - } - - pool, err := pgxpool.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer pool.Close() - - stat := pool.Stat() - require.Equal(t, int32(1), stat.IdleConns()) - require.Equal(t, int64(1), stat.AcquireCount()) - require.Equal(t, int32(1), stat.TotalConns()) - require.Equal(t, int64(0), acquireAttempts) - require.Equal(t, int64(1), connectAttempts) -} - -func TestCreateMinPoolClosesConnectionsOnError(t *testing.T) { - t.Parallel() - - config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.MinConns = int32(12) - config.MaxConns = int32(15) - config.LazyConnect = false - - acquireAttempts := int64(0) - madeConnections := int64(0) - conns := make(chan *pgx.Conn, 15) - - config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - atomic.AddInt64(&acquireAttempts, 1) - return true - } - config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { - conns <- conn - - atomic.AddInt64(&madeConnections, 1) - mc := atomic.LoadInt64(&madeConnections) - if mc == 10 { - return errors.New("mock error") + stat := pool.Stat() + if stat.IdleConns() == 12 && stat.AcquireCount() == 0 && stat.TotalConns() == 12 && atomic.LoadInt64(&acquireAttempts) == 0 && atomic.LoadInt64(&connectAttempts) == 12 { + return } - return nil - } - pool, err := pgxpool.ConnectConfig(context.Background(), config) - require.Error(t, err) - require.Nil(t, pool) - - close(conns) - for conn := range conns { - require.True(t, conn.IsClosed()) } - require.Equal(t, int64(0), acquireAttempts) - require.True(t, madeConnections >= 10, "Expected %d got %d", 10, madeConnections) -} + t.Fatal("did not reach min pool size") -func TestCreateMinPoolReturnsFirstError(t *testing.T) { - t.Parallel() - - config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.MinConns = int32(12) - config.MaxConns = int32(15) - config.LazyConnect = false - - acquireAttempts := int64(0) - connectAttempts := int64(0) - - mockErr := errors.New("mock connect error") - - config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - atomic.AddInt64(&acquireAttempts, 1) - return true - } - config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { - atomic.AddInt64(&connectAttempts, 1) - ca := atomic.LoadInt64(&connectAttempts) - if ca >= 5 { - return mockErr - } - return nil - } - - pool, err := pgxpool.ConnectConfig(context.Background(), config) - require.Nil(t, pool) - require.Error(t, err) - - require.True(t, connectAttempts >= 5, "Expected %d got %d", 5, connectAttempts) - require.ErrorIs(t, err, mockErr) } func TestPoolSendBatchBatchCloseTwice(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() diff --git a/pgxpool/rows.go b/pgxpool/rows.go index 6dc0cc34..2b11ecd3 100644 --- a/pgxpool/rows.go +++ b/pgxpool/rows.go @@ -1,29 +1,29 @@ package pgxpool import ( - "github.com/jackc/pgconn" - "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) type errRows struct { err error } -func (errRows) Close() {} -func (e errRows) Err() error { return e.err } -func (errRows) CommandTag() pgconn.CommandTag { return nil } -func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil } -func (errRows) Next() bool { return false } -func (e errRows) Scan(dest ...interface{}) error { return e.err } -func (e errRows) Values() ([]interface{}, error) { return nil, e.err } -func (e errRows) RawValues() [][]byte { return nil } +func (errRows) Close() {} +func (e errRows) Err() error { return e.err } +func (errRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } +func (errRows) FieldDescriptions() []pgconn.FieldDescription { return nil } +func (errRows) Next() bool { return false } +func (e errRows) Scan(dest ...any) error { return e.err } +func (e errRows) Values() ([]any, error) { return nil, e.err } +func (e errRows) RawValues() [][]byte { return nil } +func (e errRows) Conn() *pgx.Conn { return nil } type errRow struct { err error } -func (e errRow) Scan(dest ...interface{}) error { return e.err } +func (e errRow) Scan(dest ...any) error { return e.err } type poolRows struct { r pgx.Rows @@ -50,7 +50,7 @@ func (rows *poolRows) CommandTag() pgconn.CommandTag { return rows.r.CommandTag() } -func (rows *poolRows) FieldDescriptions() []pgproto3.FieldDescription { +func (rows *poolRows) FieldDescriptions() []pgconn.FieldDescription { return rows.r.FieldDescriptions() } @@ -66,7 +66,7 @@ func (rows *poolRows) Next() bool { return n } -func (rows *poolRows) Scan(dest ...interface{}) error { +func (rows *poolRows) Scan(dest ...any) error { err := rows.r.Scan(dest...) if err != nil { rows.Close() @@ -74,7 +74,7 @@ func (rows *poolRows) Scan(dest ...interface{}) error { return err } -func (rows *poolRows) Values() ([]interface{}, error) { +func (rows *poolRows) Values() ([]any, error) { values, err := rows.r.Values() if err != nil { rows.Close() @@ -86,13 +86,17 @@ func (rows *poolRows) RawValues() [][]byte { return rows.r.RawValues() } +func (rows *poolRows) Conn() *pgx.Conn { + return rows.r.Conn() +} + type poolRow struct { r pgx.Row c *Conn err error } -func (row *poolRow) Scan(dest ...interface{}) error { +func (row *poolRow) Scan(dest ...any) error { if row.err != nil { return row.err } diff --git a/pgxpool/stat.go b/pgxpool/stat.go index 47342be4..cfa0c4c5 100644 --- a/pgxpool/stat.go +++ b/pgxpool/stat.go @@ -3,7 +3,7 @@ package pgxpool import ( "time" - "github.com/jackc/puddle" + "github.com/jackc/puddle/v2" ) // Stat is a snapshot of Pool statistics. diff --git a/pgxpool/tx.go b/pgxpool/tx.go index 6f566e41..74df8593 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -3,8 +3,8 @@ package pgxpool import ( "context" - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) // Tx represents a database transaction acquired from a Pool. @@ -18,10 +18,6 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) { return tx.t.Begin(ctx) } -func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { - return tx.t.BeginFunc(ctx, f) -} - // Commit commits the transaction and returns the associated connection back to the Pool. Commit will return ErrTxClosed // if the Tx is already closed, but is otherwise safe to call multiple times. If the commit fails with a rollback status // (e.g. the transaction was already in a broken state) then ErrTxCommitRollback will be returned. @@ -69,22 +65,18 @@ func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementD return tx.t.Prepare(ctx, name, sql) } -func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { +func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { return tx.t.Exec(ctx, sql, arguments...) } -func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { +func (tx *Tx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { return tx.t.Query(ctx, sql, args...) } -func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { +func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { return tx.t.QueryRow(ctx, sql, args...) } -func (tx *Tx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - return tx.t.QueryFunc(ctx, sql, args, scans, f) -} - func (tx *Tx) Conn() *pgx.Conn { return tx.t.Conn() } diff --git a/pgxpool/tx_test.go b/pgxpool/tx_test.go index d66ad338..8e140bf5 100644 --- a/pgxpool/tx_test.go +++ b/pgxpool/tx_test.go @@ -5,14 +5,14 @@ import ( "os" "testing" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" ) func TestTxExec(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -26,7 +26,7 @@ func TestTxExec(t *testing.T) { func TestTxQuery(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -40,7 +40,7 @@ func TestTxQuery(t *testing.T) { func TestTxQueryRow(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -54,7 +54,7 @@ func TestTxQueryRow(t *testing.T) { func TestTxSendBatch(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -68,7 +68,7 @@ func TestTxSendBatch(t *testing.T) { func TestTxCopyFrom(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() diff --git a/pgxtest/pgxtest.go b/pgxtest/pgxtest.go new file mode 100644 index 00000000..796f850d --- /dev/null +++ b/pgxtest/pgxtest.go @@ -0,0 +1,153 @@ +// Package pgxtest provides utilities for testing pgx and packages that integrate with pgx. +package pgxtest + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/jackc/pgx/v5" +) + +var AllQueryExecModes = []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + pgx.QueryExecModeSimpleProtocol, +} + +// KnownOIDQueryExecModes is a slice of all query exec modes where the param and result OIDs are known before sending the query. +var KnownOIDQueryExecModes = []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, +} + +// ConnTestRunner controls how a *pgx.Conn is created and closed by tests. All fields are required. Use DefaultConnTestRunner to get a +// ConnTestRunner with reasonable default values. +type ConnTestRunner struct { + // CreateConfig returns a *pgx.ConnConfig suitable for use with pgx.ConnectConfig. + CreateConfig func(ctx context.Context, t testing.TB) *pgx.ConnConfig + + // AfterConnect is called after conn is established. It allows for arbitrary connection setup before a test begins. + AfterConnect func(ctx context.Context, t testing.TB, conn *pgx.Conn) + + // AfterTest is called after the test is run. It allows for validating the state of the connection before it is closed. + AfterTest func(ctx context.Context, t testing.TB, conn *pgx.Conn) + + // CloseConn closes conn. + CloseConn func(ctx context.Context, t testing.TB, conn *pgx.Conn) +} + +// DefaultConnTestRunner returns a new ConnTestRunner with all fields set to reasonable default values. +func DefaultConnTestRunner() ConnTestRunner { + return ConnTestRunner{ + CreateConfig: func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig("") + if err != nil { + t.Fatalf("ParseConfig failed: %v", err) + } + return config + }, + AfterConnect: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {}, + AfterTest: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {}, + CloseConn: func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + err := conn.Close(ctx) + if err != nil { + t.Errorf("Close failed: %v", err) + } + }, + } +} + +func (ctr *ConnTestRunner) RunTest(ctx context.Context, t testing.TB, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) { + t.Helper() + + config := ctr.CreateConfig(ctx, t) + conn, err := pgx.ConnectConfig(ctx, config) + if err != nil { + t.Fatalf("ConnectConfig failed: %v", err) + } + defer ctr.CloseConn(ctx, t, conn) + + ctr.AfterConnect(ctx, t, conn) + f(ctx, t, conn) + ctr.AfterTest(ctx, t, conn) +} + +// RunWithQueryExecModes runs a f in a new test for each element of modes with a new connection created using connector. +// If modes is nil all pgx.QueryExecModes are tested. +func RunWithQueryExecModes(ctx context.Context, t *testing.T, ctr ConnTestRunner, modes []pgx.QueryExecMode, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) { + if modes == nil { + modes = AllQueryExecModes + } + + for _, mode := range modes { + ctrWithMode := ctr + ctrWithMode.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := ctr.CreateConfig(ctx, t) + config.DefaultQueryExecMode = mode + return config + } + + t.Run(mode.String(), + func(t *testing.T) { + ctrWithMode.RunTest(ctx, t, f) + }, + ) + } +} + +type ValueRoundTripTest struct { + Param any + Result any + Test func(any) bool +} + +func RunValueRoundTripTests( + ctx context.Context, + t testing.TB, + ctr ConnTestRunner, + modes []pgx.QueryExecMode, + pgTypeName string, + tests []ValueRoundTripTest, +) { + t.Helper() + + if modes == nil { + modes = AllQueryExecModes + } + + ctr.RunTest(ctx, t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + t.Helper() + + sql := fmt.Sprintf("select $1::%s", pgTypeName) + + for i, tt := range tests { + for _, mode := range modes { + err := conn.QueryRow(ctx, sql, mode, tt.Param).Scan(tt.Result) + if err != nil { + t.Errorf("%d. %v: %v", i, mode, err) + } + + result := reflect.ValueOf(tt.Result) + if result.Kind() == reflect.Ptr { + result = result.Elem() + } + + if !tt.Test(result.Interface()) { + t.Errorf("%d. %v: unexpected result for %v: %v", i, mode, tt.Param, result.Interface()) + } + } + } + }) +} + +// SkipCockroachDB calls Skip on t with msg if the connection is to a CockroachDB server. +func SkipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) { + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip(msg) + } +} diff --git a/pipeline_test.go b/pipeline_test.go new file mode 100644 index 00000000..b8590bf9 --- /dev/null +++ b/pipeline_test.go @@ -0,0 +1,79 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" +) + +func TestPipelineWithoutPreparedOrDescribedStatements(t *testing.T) { + t.Parallel() + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pipeline := conn.PgConn().StartPipeline(ctx) + + eqb := pgx.ExtendedQueryBuilder{} + + err := eqb.Build(conn.TypeMap(), nil, []any{1, 2}) + require.NoError(t, err) + pipeline.SendQueryParams(`select $1::bigint + $2::bigint`, eqb.ParamValues, nil, eqb.ParamFormats, eqb.ResultFormats) + + err = eqb.Build(conn.TypeMap(), nil, []any{3, 4, 5}) + require.NoError(t, err) + pipeline.SendQueryParams(`select $1::bigint + $2::bigint + $3::bigint`, eqb.ParamValues, nil, eqb.ParamFormats, eqb.ResultFormats) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.True(t, ok) + rows := pgx.RowsFromResultReader(conn.TypeMap(), rr) + + rowCount := 0 + var n int64 + for rows.Next() { + err = rows.Scan(&n) + require.NoError(t, err) + rowCount++ + } + require.NoError(t, rows.Err()) + require.Equal(t, 1, rowCount) + require.Equal(t, "SELECT 1", rows.CommandTag().String()) + require.EqualValues(t, 3, n) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.True(t, ok) + rows = pgx.RowsFromResultReader(conn.TypeMap(), rr) + + rowCount = 0 + n = 0 + for rows.Next() { + err = rows.Scan(&n) + require.NoError(t, err) + rowCount++ + } + require.NoError(t, rows.Err()) + require.Equal(t, 1, rowCount) + require.Equal(t, "SELECT 1", rows.CommandTag().String()) + require.EqualValues(t, 12, n) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.True(t, ok) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + }) +} diff --git a/query_test.go b/query_test.go index 968c0ecc..b2aa5d10 100644 --- a/query_test.go +++ b/query_test.go @@ -7,19 +7,15 @@ import ( "errors" "fmt" "os" - "reflect" "strconv" "strings" "testing" "time" - "github.com/cockroachdb/apd" - "github.com/gofrs/uuid" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" - "github.com/jackc/pgtype" - "github.com/jackc/pgx/v4" - "github.com/shopspring/decimal" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -49,7 +45,7 @@ func TestConnQueryScan(t *testing.T) { t.Fatalf("conn.Query failed: %v", err) } - assert.Equal(t, "SELECT 10", string(rows.CommandTag())) + assert.Equal(t, "SELECT 10", rows.CommandTag().String()) if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") @@ -70,7 +66,7 @@ func TestConnQueryRowsFieldDescriptionsBeforeNext(t *testing.T) { defer rows.Close() require.Len(t, rows.FieldDescriptions(), 1) - assert.Equal(t, []byte("msg"), rows.FieldDescriptions()[0].Name) + assert.Equal(t, "msg", rows.FieldDescriptions()[0].Name) } func TestConnQueryWithoutResultSetCommandTag(t *testing.T) { @@ -83,7 +79,7 @@ func TestConnQueryWithoutResultSetCommandTag(t *testing.T) { assert.NoError(t, err) rows.Close() assert.NoError(t, rows.Err()) - assert.Equal(t, "CREATE TABLE", string(rows.CommandTag())) + assert.Equal(t, "CREATE TABLE", rows.CommandTag().String()) } func TestConnQueryScanWithManyColumns(t *testing.T) { @@ -113,7 +109,7 @@ func TestConnQueryScanWithManyColumns(t *testing.T) { defer rows.Close() for rows.Next() { - destPtrs := make([]interface{}, columnCount) + destPtrs := make([]any, columnCount) for i := range destPtrs { destPtrs[i] = &dest[i] } @@ -249,7 +245,7 @@ func TestConnQueryReadRowMultipleTimes(t *testing.T) { var a, b string var c int32 - var d pgtype.Unknown + var d pgtype.Text var e int32 err = rows.Scan(&a, &b, &c, &d, &e) @@ -257,7 +253,7 @@ func TestConnQueryReadRowMultipleTimes(t *testing.T) { require.Equal(t, "foo", a) require.Equal(t, "bar", b) require.Equal(t, rowCount, c) - require.Equal(t, pgtype.Null, d.Status) + require.False(t, d.Valid) require.Equal(t, rowCount, e) } } @@ -266,68 +262,6 @@ func TestConnQueryReadRowMultipleTimes(t *testing.T) { require.Equal(t, int32(10), rowCount) } -// https://github.com/jackc/pgx/issues/386 -func TestConnQueryValuesWithMultipleComplexColumnsOfSameType(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - expected0 := &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Status: pgtype.Present, - } - - expected1 := &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Status: pgtype.Present, - } - - var rowCount int32 - - rows, err := conn.Query(context.Background(), "select '{1,2,3}'::bigint[], '{4,5,6}'::bigint[] from generate_series(1,$1) n", 10) - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - defer rows.Close() - - for rows.Next() { - rowCount++ - - values, err := rows.Values() - if err != nil { - t.Fatalf("rows.Values failed: %v", err) - } - if len(values) != 2 { - t.Errorf("Expected rows.Values to return 2 values, but it returned %d", len(values)) - } - if !reflect.DeepEqual(values[0], *expected0) { - t.Errorf(`Expected values[0] to be %v, but it was %v`, *expected0, values[0]) - } - if !reflect.DeepEqual(values[1], *expected1) { - t.Errorf(`Expected values[1] to be %v, but it was %v`, *expected1, values[1]) - } - } - - if rows.Err() != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - if rowCount != 10 { - t.Error("Select called onDataRow wrong number of times") - } -} - // https://github.com/jackc/pgx/issues/228 func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) { t.Parallel() @@ -335,11 +269,13 @@ func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + pgxtest.SkipCockroachDB(t, conn, "Server does not support point type") + var s string - err := conn.QueryRow(context.Background(), "select 1").Scan(&s) - if err == nil || !(strings.Contains(err.Error(), "cannot decode binary value into string") || strings.Contains(err.Error(), "cannot assign")) { - t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err) + err := conn.QueryRow(context.Background(), "select point(1,2)").Scan(&s) + if err == nil || !(strings.Contains(err.Error(), "cannot scan point (OID 600) in binary format into *string")) { + t.Fatalf("Expected Scan to fail to scan binary value into string but: %v", err) } ensureConnValid(t, conn) @@ -356,7 +292,7 @@ func TestConnQueryRawValues(t *testing.T) { rows, err := conn.Query( context.Background(), "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, 10, ) require.NoError(t, err) @@ -440,7 +376,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { defer closeConn(t, conn) // Read a single value incorrectly - rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) + rows, err := conn.Query(context.Background(), "select n::int4 from generate_series(1,$1) n", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -461,7 +397,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } - if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { + if rows.Err().Error() != "can't scan into dest[0]: cannot scan int4 (OID 23) in binary format into *time.Time" { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } @@ -541,7 +477,7 @@ func TestConnQueryDeferredError(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") mustExec(t, conn, `create temporary table t ( id text primary key, @@ -583,7 +519,7 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server uses numeric instead of int") + pgxtest.SkipCockroachDB(t, conn, "Server uses numeric instead of int") for i := 0; i < 100; i++ { func() { @@ -661,18 +597,18 @@ func TestQueryRowCoreTypes(t *testing.T) { tests := []struct { sql string - queryArgs []interface{} - scanArgs []interface{} + queryArgs []any + scanArgs []any expected allTypes }{ - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.s}, allTypes{s: "Jack"}}, - {"select $1::float4", []interface{}{float32(1.23)}, []interface{}{&actual.f32}, allTypes{f32: 1.23}}, - {"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}}, - {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, - {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, - {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, - {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, - {"select $1::oid", []interface{}{uint32(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, + {"select $1::text", []any{"Jack"}, []any{&actual.s}, allTypes{s: "Jack"}}, + {"select $1::float4", []any{float32(1.23)}, []any{&actual.f32}, allTypes{f32: 1.23}}, + {"select $1::float8", []any{float64(1.23)}, []any{&actual.f64}, allTypes{f64: 1.23}}, + {"select $1::bool", []any{true}, []any{&actual.b}, allTypes{b: true}}, + {"select $1::timestamptz", []any{time.Unix(123, 5000)}, []any{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, + {"select $1::timestamp", []any{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []any{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, + {"select $1::date", []any{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []any{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, + {"select $1::oid", []any{uint32(42)}, []any{&actual.oid}, allTypes{oid: 42}}, } for i, tt := range tests { @@ -722,8 +658,8 @@ func TestQueryRowCoreIntegerEncoding(t *testing.T) { successfulEncodeTests := []struct { sql string - queryArg interface{} - scanArg interface{} + queryArg any + scanArg any expected allTypes }{ // Check any integer type where value is within int2 range can be encoded @@ -781,7 +717,7 @@ func TestQueryRowCoreIntegerEncoding(t *testing.T) { failedEncodeTests := []struct { sql string - queryArg interface{} + queryArg any }{ // Check any integer type where value is outside pg:int2 range cannot be encoded {"select $1::int2", int(32769)}, @@ -837,7 +773,7 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) { successfulDecodeTests := []struct { sql string - scanArg interface{} + scanArg any expected allTypes }{ // Check any integer type where value is within Go:int range can be decoded @@ -923,65 +859,64 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) { } failedDecodeTests := []struct { - sql string - scanArg interface{} - expectedErr string + sql string + scanArg any }{ // Check any integer type where value is outside Go:int8 range cannot be decoded - {"select 128::int2", &actual.i8, "is greater than"}, - {"select 128::int4", &actual.i8, "is greater than"}, - {"select 128::int8", &actual.i8, "is greater than"}, - {"select -129::int2", &actual.i8, "is less than"}, - {"select -129::int4", &actual.i8, "is less than"}, - {"select -129::int8", &actual.i8, "is less than"}, + {"select 128::int2", &actual.i8}, + {"select 128::int4", &actual.i8}, + {"select 128::int8", &actual.i8}, + {"select -129::int2", &actual.i8}, + {"select -129::int4", &actual.i8}, + {"select -129::int8", &actual.i8}, // Check any integer type where value is outside Go:int16 range cannot be decoded - {"select 32768::int4", &actual.i16, "is greater than"}, - {"select 32768::int8", &actual.i16, "is greater than"}, - {"select -32769::int4", &actual.i16, "is less than"}, - {"select -32769::int8", &actual.i16, "is less than"}, + {"select 32768::int4", &actual.i16}, + {"select 32768::int8", &actual.i16}, + {"select -32769::int4", &actual.i16}, + {"select -32769::int8", &actual.i16}, // Check any integer type where value is outside Go:int32 range cannot be decoded - {"select 2147483648::int8", &actual.i32, "is greater than"}, - {"select -2147483649::int8", &actual.i32, "is less than"}, + {"select 2147483648::int8", &actual.i32}, + {"select -2147483649::int8", &actual.i32}, // Check any integer type where value is outside Go:uint range cannot be decoded - {"select -1::int2", &actual.ui, "is less than"}, - {"select -1::int4", &actual.ui, "is less than"}, - {"select -1::int8", &actual.ui, "is less than"}, + {"select -1::int2", &actual.ui}, + {"select -1::int4", &actual.ui}, + {"select -1::int8", &actual.ui}, // Check any integer type where value is outside Go:uint8 range cannot be decoded - {"select 256::int2", &actual.ui8, "is greater than"}, - {"select 256::int4", &actual.ui8, "is greater than"}, - {"select 256::int8", &actual.ui8, "is greater than"}, - {"select -1::int2", &actual.ui8, "is less than"}, - {"select -1::int4", &actual.ui8, "is less than"}, - {"select -1::int8", &actual.ui8, "is less than"}, + {"select 256::int2", &actual.ui8}, + {"select 256::int4", &actual.ui8}, + {"select 256::int8", &actual.ui8}, + {"select -1::int2", &actual.ui8}, + {"select -1::int4", &actual.ui8}, + {"select -1::int8", &actual.ui8}, // Check any integer type where value is outside Go:uint16 cannot be decoded - {"select 65536::int4", &actual.ui16, "is greater than"}, - {"select 65536::int8", &actual.ui16, "is greater than"}, - {"select -1::int2", &actual.ui16, "is less than"}, - {"select -1::int4", &actual.ui16, "is less than"}, - {"select -1::int8", &actual.ui16, "is less than"}, + {"select 65536::int4", &actual.ui16}, + {"select 65536::int8", &actual.ui16}, + {"select -1::int2", &actual.ui16}, + {"select -1::int4", &actual.ui16}, + {"select -1::int8", &actual.ui16}, // Check any integer type where value is outside Go:uint32 range cannot be decoded - {"select 4294967296::int8", &actual.ui32, "is greater than"}, - {"select -1::int2", &actual.ui32, "is less than"}, - {"select -1::int4", &actual.ui32, "is less than"}, - {"select -1::int8", &actual.ui32, "is less than"}, + {"select 4294967296::int8", &actual.ui32}, + {"select -1::int2", &actual.ui32}, + {"select -1::int4", &actual.ui32}, + {"select -1::int8", &actual.ui32}, // Check any integer type where value is outside Go:uint64 range cannot be decoded - {"select -1::int2", &actual.ui64, "is less than"}, - {"select -1::int4", &actual.ui64, "is less than"}, - {"select -1::int8", &actual.ui64, "is less than"}, + {"select -1::int2", &actual.ui64}, + {"select -1::int4", &actual.ui64}, + {"select -1::int8", &actual.ui64}, } for i, tt := range failedDecodeTests { err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg) if err == nil { t.Errorf("%d. Expected failure to decode, but unexpectedly succeeded: %v (sql -> %v)", i, err, tt.sql) - } else if !strings.Contains(err.Error(), tt.expectedErr) { + } else if !strings.Contains(err.Error(), "can't scan") { t.Errorf("%d. Expected failure to decode, but got: %v (sql -> %v)", i, err, tt.sql) } @@ -997,7 +932,7 @@ func TestQueryRowCoreByteSlice(t *testing.T) { tests := []struct { sql string - queryArg interface{} + queryArg any expected []byte }{ {"select $1::text", "Jack", []byte("Jack")}, @@ -1028,6 +963,10 @@ func TestQueryRowErrors(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Skipping due to known server missing point type") + } + type allTypes struct { i16 int16 i int @@ -1038,14 +977,14 @@ func TestQueryRowErrors(t *testing.T) { tests := []struct { sql string - queryArgs []interface{} - scanArgs []interface{} + queryArgs []any + scanArgs []any err string }{ - // {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, - // {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "unable to assign"}, - // {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"}, + {"select $1::badtype", []any{"Jack"}, []any{&actual.i16}, `type "badtype" does not exist`}, + {"SYNTAX ERROR", []any{}, []any{&actual.i16}, "SQLSTATE 42601"}, + {"select $1::text", []any{"Jack"}, []any{&actual.i16}, "cannot scan text (OID 25) in text format into *int16"}, + {"select $1::point", []any{int(705)}, []any{&actual.s}, "unable to encode 705 into binary format for point (OID 600)"}, } for i, tt := range tests { @@ -1155,55 +1094,52 @@ func TestReadingNullByteArrays(t *testing.T) { } } -// Use github.com/shopspring/decimal as real-world database/sql custom type -// to test against. +func TestQueryNullSliceIsSet(t *testing.T) { + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + a := []int32{1, 2, 3} + err := conn.QueryRow(context.Background(), "select null::int[]").Scan(&a) + if err != nil { + t.Fatalf("conn.QueryRow failed: %v", err) + } + + if a != nil { + t.Errorf("Expected 'a' to be nil, but it was: %v", a) + } +} + func TestConnQueryDatabaseSQLScanner(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - var num decimal.Decimal + var num sql.NullFloat64 - err := conn.QueryRow(context.Background(), "select '1234.567'::decimal").Scan(&num) + err := conn.QueryRow(context.Background(), "select '1234.567'::float8").Scan(&num) if err != nil { t.Fatalf("Scan failed: %v", err) } - expected, err := decimal.NewFromString("1234.567") - if err != nil { - t.Fatal(err) - } - - if !num.Equals(expected) { - t.Errorf("Expected num to be %v, but it was %v", expected, num) - } + require.True(t, num.Valid) + require.Equal(t, 1234.567, num.Float64) ensureConnValid(t, conn) } -// Use github.com/shopspring/decimal as real-world database/sql custom type -// to test against. func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - expected, err := decimal.NewFromString("1234.567") - if err != nil { - t.Fatal(err) - } - var num decimal.Decimal + expected := sql.NullFloat64{Float64: 1234.567, Valid: true} + var actual sql.NullFloat64 - err = conn.QueryRow(context.Background(), "select $1::decimal", &expected).Scan(&num) - if err != nil { - t.Fatalf("Scan failed: %v", err) - } - - if !num.Equals(expected) { - t.Errorf("Expected num to be %v, but it was %v", expected, num) - } + err := conn.QueryRow(context.Background(), "select $1::float8", &expected).Scan(&actual) + require.NoError(t, err) + require.Equal(t, expected, actual) ensureConnValid(t, conn) } @@ -1217,38 +1153,50 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes mustExec(t, conn, "create temporary table t(n numeric)") - var d *apd.Decimal + var d *sql.NullInt64 commandTag, err := conn.Exec(context.Background(), `insert into t(n) values($1)`, d) if err != nil { t.Fatal(err) } - if string(commandTag) != "INSERT 0 1" { + if commandTag.String() != "INSERT 0 1" { t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag) } ensureConnValid(t, conn) } -func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) { +func TestConnQueryDatabaseSQLDriverScannerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - expected, err := uuid.FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8") - if err != nil { - t.Fatal(err) - } + var actual sql.NullString + err := conn.QueryRow(context.Background(), "select '6ba7b810-9dad-11d1-80b4-00c04fd430c8'::uuid").Scan(&actual) + require.NoError(t, err) - var u2 uuid.UUID - err = conn.QueryRow(context.Background(), "select $1::uuid", expected).Scan(&u2) - if err != nil { - t.Fatalf("Scan failed: %v", err) - } + require.True(t, actual.Valid) + require.Equal(t, "6ba7b810-9dad-11d1-80b4-00c04fd430c8", actual.String) - if expected != u2 { - t.Errorf("Expected u2 to be %v, but it was %v", expected, u2) - } + ensureConnValid(t, conn) +} + +// https://github.com/jackc/pgx/issues/1273#issuecomment-1221672175 +func TestConnQueryDatabaseSQLDriverValuerTextWhenBinaryIsPreferred(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + arg := sql.NullString{String: "1.234", Valid: true} + var result pgtype.Numeric + err := conn.QueryRow(context.Background(), "select $1::numeric", arg).Scan(&result) + require.NoError(t, err) + + require.True(t, result.Valid) + f64, err := result.Float64Value() + require.NoError(t, err) + require.Equal(t, pgtype.Float8{Float64: 1.234, Valid: true}, f64) ensureConnValid(t, conn) } @@ -1354,7 +1302,7 @@ func TestQueryContextErrorWhileReceivingRows(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server uses numeric instead of int") + pgxtest.SkipCockroachDB(t, conn, "Server uses numeric instead of int") ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -1449,7 +1397,7 @@ func TestScanRow(t *testing.T) { for resultReader.NextRow() { var n int32 - err := pgx.ScanRow(conn.ConnInfo(), resultReader.FieldDescriptions(), resultReader.Values(), &n) + err := pgx.ScanRow(conn.TypeMap(), resultReader.FieldDescriptions(), resultReader.Values(), &n) assert.NoError(t, err) sum += n rowCount++ @@ -1476,7 +1424,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::int8", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1493,7 +1441,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::float8", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1509,8 +1457,8 @@ func TestConnSimpleProtocol(t *testing.T) { var actual bool err := conn.QueryRow( context.Background(), - "select $1", - pgx.QuerySimpleProtocol(true), + "select $1::boolean", + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1527,7 +1475,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bytea", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1544,7 +1492,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::text", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1569,7 +1517,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::text[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1590,7 +1538,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::smallint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1611,7 +1559,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::int[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1632,7 +1580,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1653,7 +1601,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1674,7 +1622,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::smallint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1695,7 +1643,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1716,7 +1664,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1737,7 +1685,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1758,7 +1706,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::float4[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1779,7 +1727,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::float8[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1792,12 +1740,12 @@ func TestConnSimpleProtocol(t *testing.T) { { if conn.PgConn().ParameterStatus("crdb_version") == "" { // CockroachDB doesn't support circle type. - expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present} + expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Valid: true} actual := expected err := conn.QueryRow( context.Background(), "select $1::circle", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, &expected, ).Scan(&actual) if err != nil { @@ -1824,8 +1772,8 @@ func TestConnSimpleProtocol(t *testing.T) { var actualString string err := conn.QueryRow( context.Background(), - "select $1::int8, $2::float8, $3, $4::bytea, $5::text", - pgx.QuerySimpleProtocol(true), + "select $1::int8, $2::float8, $3::boolean, $4::bytea, $5::text", + pgx.QueryExecModeSimpleProtocol, expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString, ).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString) if err != nil { @@ -1856,7 +1804,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1876,7 +1824,7 @@ func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support changing client_encoding (https://www.cockroachlabs.com/docs/stable/set-vars.html)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support changing client_encoding (https://www.cockroachlabs.com/docs/stable/set-vars.html)") mustExec(t, conn, "set client_encoding to 'SQL_ASCII'") @@ -1884,7 +1832,7 @@ func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, "test", ).Scan(&expected) if err == nil { @@ -1900,7 +1848,7 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support standard_conforming_strings = off (https://github.com/cockroachdb/cockroach/issues/36215)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support standard_conforming_strings = off (https://github.com/cockroachdb/cockroach/issues/36215)") mustExec(t, conn, "set standard_conforming_strings to off") @@ -1908,7 +1856,7 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, `\'; drop table users; --`, ).Scan(&expected) if err == nil { @@ -1918,63 +1866,14 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryStatementCacheModes(t *testing.T) { - t.Parallel() - - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - - tests := []struct { - name string - buildStatementCache pgx.BuildStatementCacheFunc - }{ - { - name: "disabled", - buildStatementCache: nil, - }, - { - name: "prepare", - buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - }, - }, - { - name: "describe", - buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - }, - }, - } - - for _, tt := range tests { - func() { - config.BuildStatementCache = tt.buildStatementCache - conn := mustConnect(t, config) - defer closeConn(t, conn) - - var n int - err := conn.QueryRow(context.Background(), "select 1").Scan(&n) - assert.NoError(t, err, tt.name) - assert.Equal(t, 1, n, tt.name) - - err = conn.QueryRow(context.Background(), "select 2").Scan(&n) - assert.NoError(t, err, tt.name) - assert.Equal(t, 2, n, tt.name) - - err = conn.QueryRow(context.Background(), "select 1").Scan(&n) - assert.NoError(t, err, tt.name) - assert.Equal(t, 1, n, tt.name) - - ensureConnValid(t, conn) - }() - } -} - // https://github.com/jackc/pgx/issues/895 -func TestQueryErrorWithNilStatementCacheMode(t *testing.T) { +func TestQueryErrorWithDisabledStatementCache(t *testing.T) { t.Parallel() config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(t, config) defer closeConn(t, conn) @@ -2000,101 +1899,101 @@ func TestQueryErrorWithNilStatementCacheMode(t *testing.T) { ensureConnValid(t, conn) } -func TestConnQueryFunc(t *testing.T) { +func TestQueryWithQueryRewriter(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - var actualResults []interface{} - - var a, b int - ct, err := conn.QueryFunc( - context.Background(), - "select n, n * 2 from generate_series(1, $1) n", - []interface{}{3}, - []interface{}{&a, &b}, - func(pgx.QueryFuncRow) error { - actualResults = append(actualResults, []interface{}{a, b}) - return nil - }, - ) + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + qr := testQueryRewriter{sql: "select $1::int", args: []any{42}} + rows, err := conn.Query(ctx, "should be replaced", &qr) require.NoError(t, err) - expectedResults := []interface{}{ - []interface{}{1, 2}, - []interface{}{2, 4}, - []interface{}{3, 6}, + var n int32 + var rowCount int + for rows.Next() { + rowCount++ + err = rows.Scan(&n) + require.NoError(t, err) } - require.Equal(t, expectedResults, actualResults) - require.EqualValues(t, 3, ct.RowsAffected()) + + require.NoError(t, rows.Err()) }) } -func TestConnQueryFuncScanError(t *testing.T) { - t.Parallel() +// This example uses Query without using any helpers to read the results. Normally CollectRows, ForEachRow, or another +// helper function should be used. +func ExampleConn_Query() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - var actualResults []interface{} - - var a, b int - ct, err := conn.QueryFunc( - context.Background(), - "select 'foo', 'bar' from generate_series(1, $1) n", - []interface{}{3}, - []interface{}{&a, &b}, - func(pgx.QueryFuncRow) error { - actualResults = append(actualResults, []interface{}{a, b}) - return nil - }, - ) - require.EqualError(t, err, "can't scan into dest[0]: unable to assign to *int") - require.Nil(t, ct) - }) -} - -func TestConnQueryFuncAbort(t *testing.T) { - t.Parallel() - - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - var a, b int - ct, err := conn.QueryFunc( - context.Background(), - "select n, n * 2 from generate_series(1, $1) n", - []interface{}{3}, - []interface{}{&a, &b}, - func(pgx.QueryFuncRow) error { - return errors.New("abort") - }, - ) - require.EqualError(t, err, "abort") - require.Nil(t, ct) - }) -} - -func ExampleConn_QueryFunc() { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) if err != nil { fmt.Printf("Unable to establish connection: %v", err) return } - var a, b int - _, err = conn.QueryFunc( - context.Background(), - "select n, n * 2 from generate_series(1, $1) n", - []interface{}{3}, - []interface{}{&a, &b}, - func(pgx.QueryFuncRow) error { - fmt.Printf("%v, %v\n", a, b) - return nil - }, - ) + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. + fmt.Println(`Cheeseburger: $10 +Fries: $5 +Soft Drink: $3`) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table products ( + id int primary key generated by default as identity, + name varchar(100) not null, + price int not null +); + +insert into products (name, price) values + ('Cheeseburger', 10), + ('Double Cheeseburger', 14), + ('Fries', 5), + ('Soft Drink', 3); +`) if err != nil { - fmt.Printf("QueryFunc error: %v", err) + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + rows, err := conn.Query(ctx, "select name, price from products where price < $1 order by price desc", 12) + + // It is unnecessary to check err. If an error occurred it will be returned by rows.Err() later. But in rare + // cases it may be useful to detect the error as early as possible. + if err != nil { + fmt.Printf("Query error: %v", err) + return + } + + // Ensure rows is closed. It is safe to close rows multiple times. + defer rows.Close() + + // Iterate through the result set + for rows.Next() { + var name string + var price int32 + + err = rows.Scan(&name, &price) + if err != nil { + fmt.Printf("Scan error: %v", err) + return + } + + fmt.Printf("%s: $%d\n", name, price) + } + + // rows is closed automatically when rows.Next() returns false so it is not necessary to manually close rows. + + // The first error encountered by the original Query call, rows.Next or rows.Scan will be returned here. + if rows.Err() != nil { + fmt.Printf("rows error: %v", err) return } // Output: - // 1, 2 - // 2, 4 - // 3, 6 + // Cheeseburger: $10 + // Fries: $5 + // Soft Drink: $3 } diff --git a/rows.go b/rows.go index 4749ead9..33d8ab09 100644 --- a/rows.go +++ b/rows.go @@ -4,11 +4,12 @@ import ( "context" "errors" "fmt" + "reflect" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v5/internal/stmtcache" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" ) // Rows is the result set returned from *Conn.Query. Rows must be closed before @@ -32,7 +33,7 @@ type Rows interface { // CommandTag returns the command tag from this query. It is only available after Rows is closed. CommandTag() pgconn.CommandTag - FieldDescriptions() []pgproto3.FieldDescription + FieldDescriptions() []pgconn.FieldDescription // Next prepares the next row for reading. It returns true if there is another // row and false if no more rows are available. It automatically closes rows @@ -43,16 +44,20 @@ type Rows interface { // dest can include pointers to core types, values implementing the Scanner // interface, and nil. nil will skip the value entirely. It is an error to // call Scan without first calling Next() and checking that it returned true. - Scan(dest ...interface{}) error + Scan(dest ...any) error // Values returns the decoded row values. As with Scan(), it is an error to // call Values without first calling Next() and checking that it returned // true. - Values() ([]interface{}, error) + Values() ([]any, error) - // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid until the next Next - // call or the Rows is closed. However, the underlying byte data is safe to retain a reference to and mutate. + // RawValues returns the unparsed bytes of the row values. The returned data is only valid until the next Next + // call or the Rows is closed. RawValues() [][]byte + + // Conn returns the underlying *Conn on which the query was executed. This may return nil if Rows did not come from a + // *Conn (e.g. if it was created by RowsFromResultReader) + Conn() *Conn } // Row is a convenience wrapper over Rows that is returned by QueryRow. @@ -65,19 +70,32 @@ type Row interface { // Scan works the same as Rows. with the following exceptions. If no // rows were found it returns ErrNoRows. If multiple rows are returned it // ignores all but the first. - Scan(dest ...interface{}) error + Scan(dest ...any) error +} + +// RowScanner scans an entire row at a time into the RowScanner. +type RowScanner interface { + // ScanRows scans the row. + ScanRow(rows Rows) error } // connRow implements the Row interface for Conn.QueryRow. -type connRow connRows +type connRow baseRows -func (r *connRow) Scan(dest ...interface{}) (err error) { - rows := (*connRows)(r) +func (r *connRow) Scan(dest ...any) (err error) { + rows := (*baseRows)(r) if rows.Err() != nil { return rows.Err() } + for _, d := range dest { + if _, ok := d.(*pgtype.DriverBytes); ok { + rows.Close() + return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow") + } + } + if !rows.Next() { if rows.Err() == nil { return ErrNoRows @@ -90,37 +108,37 @@ func (r *connRow) Scan(dest ...interface{}) (err error) { return rows.Err() } -type rowLog interface { - shouldLog(lvl LogLevel) bool - log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) -} +// baseRows implements the Rows interface for Conn.Query. +type baseRows struct { + typeMap *pgtype.Map + resultReader *pgconn.ResultReader + + values [][]byte -// connRows implements the Rows interface for Conn.Query. -type connRows struct { - ctx context.Context - logger rowLog - connInfo *pgtype.ConnInfo - values [][]byte - rowCount int - err error commandTag pgconn.CommandTag - startTime time.Time - sql string - args []interface{} + err error closed bool - conn *Conn - - resultReader *pgconn.ResultReader - multiResultReader *pgconn.MultiResultReader scanPlans []pgtype.ScanPlan + scanTypes []reflect.Type + + conn *Conn + multiResultReader *pgconn.MultiResultReader + + queryTracer QueryTracer + batchTracer BatchTracer + ctx context.Context + startTime time.Time + sql string + args []any + rowCount int } -func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription { +func (rows *baseRows) FieldDescriptions() []pgconn.FieldDescription { return rows.resultReader.FieldDescriptions() } -func (rows *connRows) Close() { +func (rows *baseRows) Close() { if rows.closed { return } @@ -142,35 +160,36 @@ func (rows *connRows) Close() { } } - if rows.logger != nil { - endTime := time.Now() + if rows.err != nil && rows.conn != nil && rows.sql != "" { + if stmtcache.IsStatementInvalid(rows.err) { + if sc := rows.conn.statementCache; sc != nil { + sc.Invalidate(rows.sql) + } - if rows.err == nil { - if rows.logger.shouldLog(LogLevelInfo) { - rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) - } - } else { - if rows.logger.shouldLog(LogLevelError) { - rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "time": endTime.Sub(rows.startTime), "args": logQueryArgs(rows.args)}) - } - if rows.err != nil && rows.conn.stmtcache != nil { - rows.conn.stmtcache.StatementErrored(rows.sql, rows.err) + if sc := rows.conn.descriptionCache; sc != nil { + sc.Invalidate(rows.sql) } } } + + if rows.batchTracer != nil { + rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err}) + } else if rows.queryTracer != nil { + rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err}) + } } -func (rows *connRows) CommandTag() pgconn.CommandTag { +func (rows *baseRows) CommandTag() pgconn.CommandTag { return rows.commandTag } -func (rows *connRows) Err() error { +func (rows *baseRows) Err() error { return rows.err } // fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. -func (rows *connRows) fatal(err error) { +func (rows *baseRows) fatal(err error) { if rows.err != nil { return } @@ -179,7 +198,7 @@ func (rows *connRows) fatal(err error) { rows.Close() } -func (rows *connRows) Next() bool { +func (rows *baseRows) Next() bool { if rows.closed { return false } @@ -194,8 +213,8 @@ func (rows *connRows) Next() bool { } } -func (rows *connRows) Scan(dest ...interface{}) error { - ci := rows.connInfo +func (rows *baseRows) Scan(dest ...any) error { + m := rows.typeMap fieldDescriptions := rows.FieldDescriptions() values := rows.values @@ -204,6 +223,13 @@ func (rows *connRows) Scan(dest ...interface{}) error { rows.fatal(err) return err } + + if len(dest) == 1 { + if rc, ok := dest[0].(RowScanner); ok { + return rc.ScanRow(rows) + } + } + if len(fieldDescriptions) != len(dest) { err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) rows.fatal(err) @@ -212,8 +238,10 @@ func (rows *connRows) Scan(dest ...interface{}) error { if rows.scanPlans == nil { rows.scanPlans = make([]pgtype.ScanPlan, len(values)) + rows.scanTypes = make([]reflect.Type, len(values)) for i := range dest { - rows.scanPlans[i] = ci.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + rows.scanTypes[i] = reflect.TypeOf(dest[i]) } } @@ -222,7 +250,12 @@ func (rows *connRows) Scan(dest ...interface{}) error { continue } - err := rows.scanPlans[i].Scan(ci, fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], dst) + if rows.scanTypes[i] != reflect.TypeOf(dst) { + rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + rows.scanTypes[i] = reflect.TypeOf(dest[i]) + } + + err := rows.scanPlans[i].Scan(values[i], dst) if err != nil { err = ScanArgError{ColumnIndex: i, Err: err} rows.fatal(err) @@ -233,12 +266,12 @@ func (rows *connRows) Scan(dest ...interface{}) error { return nil } -func (rows *connRows) Values() ([]interface{}, error) { +func (rows *baseRows) Values() ([]any, error) { if rows.closed { return nil, errors.New("rows is closed") } - values := make([]interface{}, 0, len(rows.FieldDescriptions())) + values := make([]any, 0, len(rows.FieldDescriptions())) for i := range rows.FieldDescriptions() { buf := rows.values[i] @@ -249,49 +282,20 @@ func (rows *connRows) Values() ([]interface{}, error) { continue } - if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok { - value := dt.Value - - switch fd.Format { - case TextFormatCode: - decoder, ok := value.(pgtype.TextDecoder) - if !ok { - decoder = &pgtype.GenericText{} - } - err := decoder.DecodeText(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, decoder.(pgtype.Value).Get()) - case BinaryFormatCode: - decoder, ok := value.(pgtype.BinaryDecoder) - if !ok { - decoder = &pgtype.GenericBinary{} - } - err := decoder.DecodeBinary(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, value.Get()) - default: - rows.fatal(errors.New("Unknown format code")) + if dt, ok := rows.typeMap.TypeForOID(fd.DataTypeOID); ok { + value, err := dt.Codec.DecodeValue(rows.typeMap, fd.DataTypeOID, fd.Format, buf) + if err != nil { + rows.fatal(err) } + values = append(values, value) } else { switch fd.Format { case TextFormatCode: - decoder := &pgtype.GenericText{} - err := decoder.DecodeText(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, decoder.Get()) + values = append(values, string(buf)) case BinaryFormatCode: - decoder := &pgtype.GenericBinary{} - err := decoder.DecodeBinary(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, decoder.Get()) + newBuf := make([]byte, len(buf)) + copy(newBuf, buf) + values = append(values, newBuf) default: rows.fatal(errors.New("Unknown format code")) } @@ -305,10 +309,14 @@ func (rows *connRows) Values() ([]interface{}, error) { return values, rows.Err() } -func (rows *connRows) RawValues() [][]byte { +func (rows *baseRows) RawValues() [][]byte { return rows.values } +func (rows *baseRows) Conn() *Conn { + return rows.conn +} + type ScanArgError struct { ColumnIndex int Err error @@ -324,11 +332,11 @@ func (e ScanArgError) Unwrap() error { // ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface. // -// connInfo - OID to Go type mapping. +// typeMap - OID to Go type mapping. // fieldDescriptions - OID and format of values // values - the raw data as returned from the PostgreSQL server // dest - the destination that values will be decoded into -func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error { +func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, values [][]byte, dest ...any) error { if len(fieldDescriptions) != len(values) { return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) } @@ -341,7 +349,7 @@ func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescri continue } - err := connInfo.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) + err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) if err != nil { return ScanArgError{ColumnIndex: i, Err: err} } @@ -349,3 +357,187 @@ func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescri return nil } + +// RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used +// to read from the lower level pgconn interface. +func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows { + return &baseRows{ + typeMap: typeMap, + resultReader: resultReader, + } +} + +// ForEachRow iterates through rows. For each row it scans into the elements of scans and calls fn. If any row +// fails to scan or fn returns an error the query will be aborted and the error will be returned. Rows will be closed +// when ForEachRow returns. +func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) { + defer rows.Close() + + for rows.Next() { + err := rows.Scan(scans...) + if err != nil { + return pgconn.CommandTag{}, err + } + + err = fn() + if err != nil { + return pgconn.CommandTag{}, err + } + } + + if err := rows.Err(); err != nil { + return pgconn.CommandTag{}, err + } + + return rows.CommandTag(), nil +} + +// CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call. +type CollectableRow interface { + FieldDescriptions() []pgconn.FieldDescription + Scan(dest ...any) error + Values() ([]any, error) + RawValues() [][]byte +} + +// RowToFunc is a function that scans or otherwise converts row to a T. +type RowToFunc[T any] func(row CollectableRow) (T, error) + +// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. +func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { + defer rows.Close() + + slice := []T{} + + for rows.Next() { + value, err := fn(rows) + if err != nil { + return nil, err + } + slice = append(slice, value) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return slice, nil +} + +// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. +// CollectOneRow is to CollectRows as QueryRow is to Query. +func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { + defer rows.Close() + + var value T + var err error + + if !rows.Next() { + return value, ErrNoRows + } + + value, err = fn(rows) + if err != nil { + return value, err + } + + rows.Close() + return value, rows.Err() +} + +// RowTo returns a T scanned from row. +func RowTo[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&value) + return value, err +} + +// RowTo returns a the address of a T scanned from row. +func RowToAddrOf[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&value) + return &value, err +} + +// RowToMap returns a map scanned from row. +func RowToMap(row CollectableRow) (map[string]any, error) { + var value map[string]any + err := row.Scan((*mapRowScanner)(&value)) + return value, err +} + +type mapRowScanner map[string]any + +func (rs *mapRowScanner) ScanRow(rows Rows) error { + values, err := rows.Values() + if err != nil { + return err + } + + *rs = make(mapRowScanner, len(values)) + + for i := range values { + (*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i] + } + + return nil +} + +// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row +// has fields. The row and T fields will by matched by position. +func RowToStructByPos[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + return value, err +} + +// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a +// public fields as row has fields. The row and T fields will by matched by position. +func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + return &value, err +} + +type positionalStructRowScanner struct { + ptrToStruct any +} + +func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { + dst := rs.ptrToStruct + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() != reflect.Ptr { + return fmt.Errorf("dst not a pointer") + } + + dstElemValue := dstValue.Elem() + scanTargets := rs.appendScanTargets(dstElemValue, nil) + + if len(rows.RawValues()) > len(scanTargets) { + return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets)) + } + + return rows.Scan(scanTargets...) +} + +func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any { + dstElemType := dstElemValue.Type() + + if scanTargets == nil { + scanTargets = make([]any, 0, dstElemType.NumField()) + } + + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath == "" { + // Handle anoymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + scanTargets = append(scanTargets, rs.appendScanTargets(dstElemValue.Field(i), scanTargets)...) + } else { + scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) + } + } + } + + return scanTargets +} diff --git a/rows_test.go b/rows_test.go new file mode 100644 index 00000000..7aeafac8 --- /dev/null +++ b/rows_test.go @@ -0,0 +1,453 @@ +package pgx_test + +import ( + "context" + "errors" + "fmt" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testRowScanner struct { + name string + age int32 +} + +func (rs *testRowScanner) ScanRow(rows pgx.Rows) error { + return rows.Scan(&rs.name, &rs.age) +} + +func TestRowScanner(t *testing.T) { + t.Parallel() + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var s testRowScanner + err := conn.QueryRow(ctx, "select 'Adam' as name, 72 as height").Scan(&s) + require.NoError(t, err) + require.Equal(t, "Adam", s.name) + require.Equal(t, int32(72), s.age) + }) +} + +func TestForEachRow(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var actualResults []any + + rows, _ := conn.Query( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + 3, + ) + var a, b int + ct, err := pgx.ForEachRow(rows, []any{&a, &b}, func() error { + actualResults = append(actualResults, []any{a, b}) + return nil + }) + require.NoError(t, err) + + expectedResults := []any{ + []any{1, 2}, + []any{2, 4}, + []any{3, 6}, + } + require.Equal(t, expectedResults, actualResults) + require.EqualValues(t, 3, ct.RowsAffected()) + }) +} + +func TestForEachRowScanError(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var actualResults []any + + rows, _ := conn.Query( + context.Background(), + "select 'foo', 'bar' from generate_series(1, $1) n", + 3, + ) + var a, b int + ct, err := pgx.ForEachRow(rows, []any{&a, &b}, func() error { + actualResults = append(actualResults, []any{a, b}) + return nil + }) + require.EqualError(t, err, "can't scan into dest[0]: cannot scan text (OID 25) in text format into *int") + require.Equal(t, pgconn.CommandTag{}, ct) + }) +} + +func TestForEachRowAbort(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + 3, + ) + var a, b int + ct, err := pgx.ForEachRow(rows, []any{&a, &b}, func() error { + return errors.New("abort") + }) + require.EqualError(t, err, "abort") + require.Equal(t, pgconn.CommandTag{}, ct) + }) +} + +func ExampleForEachRow() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + 3, + ) + var a, b int + _, err = pgx.ForEachRow(rows, []any{&a, &b}, func() error { + fmt.Printf("%v, %v\n", a, b) + return nil + }) + if err != nil { + fmt.Printf("ForEachRow error: %v", err) + return + } + + // Output: + // 1, 2 + // 2, 4 + // 3, 6 +} + +func TestCollectRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), numbers[i]) + } + }) +} + +// This example uses CollectRows with a manually written collector function. In most cases RowTo, RowToAddrOf, +// RowToStructByPos, RowToAddrOfStructByPos, or another generic function would be used. +func ExampleCollectRows() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`) + numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + fmt.Println(numbers) + + // Output: + // [1 2 3 4 5] +} + +func TestCollectOneRow(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.NoError(t, err) + assert.Equal(t, int32(42), n) + }) +} + +func TestCollectOneRowNotFound(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42 where false`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.ErrorIs(t, err, pgx.ErrNoRows) + assert.Equal(t, int32(0), n) + }) +} + +func TestCollectOneRowIgnoresExtraRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(42, 99) n`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, int32(42), n) + }) +} + +func TestRowTo(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32]) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), numbers[i]) + } + }) +} + +func ExampleRowTo() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`) + numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + fmt.Println(numbers) + + // Output: + // [1 2 3 4 5] +} + +func TestRowToAddrOf(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, pgx.RowToAddrOf[int32]) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), *numbers[i]) + } + }) +} + +func ExampleRowToAddrOf() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`) + pNumbers, err := pgx.CollectRows(rows, pgx.RowToAddrOf[int32]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, p := range pNumbers { + fmt.Println(*p) + } + + // Output: + // 1 + // 2 + // 3 + // 4 + // 5 +} + +func TestRowToMap(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToMap) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i]["name"]) + assert.EqualValues(t, i, slice[i]["age"]) + } + }) +} + +func TestRowToStructByPos(t *testing.T) { + type person struct { + Name string + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i].Name) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} + +func TestRowToStructByPosEmbeddedStruct(t *testing.T) { + type Name struct { + First string + Last string + } + + type person struct { + Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "John", slice[i].Name.First) + assert.Equal(t, "Smith", slice[i].Name.Last) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} + +// Pointer to struct is not supported. But check that we don't panic. +func TestRowToStructByPosEmbeddedPointerToStruct(t *testing.T) { + type Name struct { + First string + Last string + } + + type person struct { + *Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + _, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.EqualError(t, err, "got 3 values, but dst struct has only 2 fields") + }) +} + +func ExampleRowToStructByPos() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. + fmt.Println(`Cheeseburger: $10 +Fries: $5 +Soft Drink: $3`) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table products ( + id int primary key generated by default as identity, + name varchar(100) not null, + price int not null +); + +insert into products (name, price) values + ('Cheeseburger', 10), + ('Double Cheeseburger', 14), + ('Fries', 5), + ('Soft Drink', 3); +`) + if err != nil { + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + type product struct { + ID int32 + Name string + Price int32 + } + + rows, _ := conn.Query(ctx, "select * from products where price < $1 order by price desc", 12) + products, err := pgx.CollectRows(rows, pgx.RowToStructByPos[product]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, p := range products { + fmt.Printf("%s: $%d\n", p.Name, p.Price) + } + + // Output: + // Cheeseburger: $10 + // Fries: $5 + // Soft Drink: $3 +} + +func TestRowToAddrOfStructPos(t *testing.T) { + type person struct { + Name string + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToAddrOfStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i].Name) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} diff --git a/stdlib/sql.go b/stdlib/sql.go index da377ece..fc0b0239 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -2,50 +2,58 @@ // // A database/sql connection can be established through sql.Open. // -// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") -// if err != nil { -// return err -// } +// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") +// if err != nil { +// return err +// } // // Or from a DSN string. // -// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") -// if err != nil { -// return err -// } +// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") +// if err != nil { +// return err +// } // // Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the // pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used // with sql.Open. // -// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) -// connConfig.Logger = myLogger -// connStr := stdlib.RegisterConnConfig(connConfig) -// db, _ := sql.Open("pgx", connStr) +// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) +// connConfig.Logger = myLogger +// connStr := stdlib.RegisterConnConfig(connConfig) +// db, _ := sql.Open("pgx", connStr) // -// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. -// It does not support named parameters. +// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. It does not support named parameters. // -// db.QueryRow("select * from users where id=$1", userID) +// db.QueryRow("select * from users where id=$1", userID) // -// In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard -// database/sql.DB connection pool. This allows operations that use pgx specific functionality. +// (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection pool. This allows +// operations that use pgx specific functionality. // -// // Given db is a *sql.DB -// conn, err := db.Conn(context.Background()) -// if err != nil { -// // handle error from acquiring connection from DB pool -// } +// // Given db is a *sql.DB +// conn, err := db.Conn(context.Background()) +// if err != nil { +// // handle error from acquiring connection from DB pool +// } // -// err = conn.Raw(func(driverConn interface{}) error { -// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn -// // Do pgx specific stuff with conn -// conn.CopyFrom(...) -// return nil -// }) -// if err != nil { -// // handle error that occurred while using *pgx.Conn -// } +// err = conn.Raw(func(driverConn any) error { +// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn +// // Do pgx specific stuff with conn +// conn.CopyFrom(...) +// return nil +// }) +// if err != nil { +// // handle error that occurred while using *pgx.Conn +// } +// +// PostgreSQL Specific Data Types +// +// The pgtype package provides support for PostgreSQL specific types. *pgtype.Map.SQLScanner is an adapter that makes +// these types usable as a sql.Scanner. +// +// m := pgtype.NewMap() +// var a []int64 +// err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) package stdlib import ( @@ -63,9 +71,9 @@ import ( "sync" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgtype" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" ) // Only intrinsic types should be binary format with database/sql. @@ -73,17 +81,10 @@ var databaseSQLResultFormats pgx.QueryResultFormatsByOID var pgxDriver *Driver -type ctxKey int - -var ctxKeyFakeTx ctxKey = 0 - -var ErrNotPgx = errors.New("not pgx *sql.DB") - func init() { pgxDriver = &Driver{ configs: make(map[string]*pgx.ConnConfig), } - fakeTxConns = make(map[*pgx.Conn]*sql.Tx) sql.Register("pgx", pgxDriver) databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ @@ -103,11 +104,6 @@ func init() { } } -var ( - fakeTxMutex sync.Mutex - fakeTxConns map[*pgx.Conn]*sql.Tx -) - // OptionOpenDB options for configuring the driver when opening a new db pool. type OptionOpenDB func(*connector) @@ -312,11 +308,12 @@ func UnregisterConnConfig(connStr string) { } type Conn struct { - conn *pgx.Conn - psCount int64 // Counter used for creating unique prepared statement names - driver *Driver - connConfig pgx.ConnConfig - resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused + conn *pgx.Conn + psCount int64 // Counter used for creating unique prepared statement names + driver *Driver + connConfig pgx.ConnConfig + resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused + lastResetSessionTime time.Time } // Conn returns the underlying *pgx.Conn @@ -359,11 +356,6 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e return nil, driver.ErrBadConn } - if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok { - *pconn = c.conn - return fakeTx{}, nil - } - var pgxOpts pgx.TxOptions switch sql.IsolationLevel(opts.Isolation) { case sql.LevelDefault: @@ -413,7 +405,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na return nil, driver.ErrBadConn } - args := []interface{}{databaseSQLResultFormats} + args := []any{databaseSQLResultFormats} args = append(args, namedValueToInterface(argsV)...) rows, err := c.conn.Query(ctx, query, args...) @@ -459,6 +451,14 @@ func (c *Conn) ResetSession(ctx context.Context) error { return driver.ErrBadConn } + now := time.Now() + if now.Sub(c.lastResetSessionTime) > time.Second { + if err := c.conn.PgConn().CheckConn(); err != nil { + return driver.ErrBadConn + } + } + c.lastResetSessionTime = now + return c.resetSessionFunc(ctx, c.conn) } @@ -519,7 +519,7 @@ func (r *Rows) Columns() []string { // ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned. func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { - if dt, ok := r.conn.conn.ConnInfo().DataTypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { + if dt, ok := r.conn.conn.TypeMap().TypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { return strings.ToUpper(dt.Name) } @@ -594,7 +594,7 @@ func (r *Rows) Close() error { } func (r *Rows) Next(dest []driver.Value) error { - ci := r.conn.conn.ConnInfo() + m := r.conn.conn.TypeMap() fieldDescriptions := r.rows.FieldDescriptions() if r.valueFuncs == nil { @@ -607,23 +607,23 @@ func (r *Rows) Next(dest []driver.Value) error { switch fd.DataTypeOID { case pgtype.BoolOID: var d bool - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return d, err } case pgtype.ByteaOID: var d []byte - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return d, err } - case pgtype.CIDOID: - var d pgtype.CID - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID: + var d pgtype.Uint32 + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) if err != nil { return nil, err } @@ -631,9 +631,9 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.DateOID: var d pgtype.Date - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) if err != nil { return nil, err } @@ -641,74 +641,54 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.Float4OID: var d float32 - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return float64(d), err } case pgtype.Float8OID: var d float64 - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return d, err } case pgtype.Int2OID: var d int16 - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return int64(d), err } case pgtype.Int4OID: var d int32 - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return int64(d), err } case pgtype.Int8OID: var d int64 - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return d, err } - case pgtype.JSONOID: - var d pgtype.JSON - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + case pgtype.JSONOID, pgtype.JSONBOID: + var d []byte + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) if err != nil { return nil, err } - return d.Value() - } - case pgtype.JSONBOID: - var d pgtype.JSONB - scanPlan := ci.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) - if err != nil { - return nil, err - } - return d.Value() - } - case pgtype.OIDOID: - var d pgtype.OIDValue - scanPlan := ci.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) - if err != nil { - return nil, err - } - return d.Value() + return d, nil } case pgtype.TimestampOID: var d pgtype.Timestamp - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) if err != nil { return nil, err } @@ -716,19 +696,9 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.TimestamptzOID: var d pgtype.Timestamptz - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) - if err != nil { - return nil, err - } - return d.Value() - } - case pgtype.XIDOID: - var d pgtype.XID - scanPlan := ci.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) if err != nil { return nil, err } @@ -736,9 +706,9 @@ func (r *Rows) Next(dest []driver.Value) error { } default: var d string - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return d, err } } @@ -776,11 +746,11 @@ func (r *Rows) Next(dest []driver.Value) error { return nil } -func valueToInterface(argsV []driver.Value) []interface{} { - args := make([]interface{}, 0, len(argsV)) +func valueToInterface(argsV []driver.Value) []any { + args := make([]any, 0, len(argsV)) for _, v := range argsV { if v != nil { - args = append(args, v.(interface{})) + args = append(args, v.(any)) } else { args = append(args, nil) } @@ -788,11 +758,11 @@ func valueToInterface(argsV []driver.Value) []interface{} { return args } -func namedValueToInterface(argsV []driver.NamedValue) []interface{} { - args := make([]interface{}, 0, len(argsV)) +func namedValueToInterface(argsV []driver.NamedValue) []any { + args := make([]any, 0, len(argsV)) for _, v := range argsV { if v.Value != nil { - args = append(args, v.Value.(interface{})) + args = append(args, v.Value.(any)) } else { args = append(args, nil) } @@ -808,55 +778,3 @@ type wrapTx struct { func (wtx wrapTx) Commit() error { return wtx.tx.Commit(wtx.ctx) } func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(wtx.ctx) } - -type fakeTx struct{} - -func (fakeTx) Commit() error { return nil } - -func (fakeTx) Rollback() error { return nil } - -// AcquireConn acquires a *pgx.Conn from database/sql connection pool. It must be released with ReleaseConn. -// -// In Go 1.13 this functionality has been incorporated into the standard library in the db.Conn.Raw() method. -func AcquireConn(db *sql.DB) (*pgx.Conn, error) { - var conn *pgx.Conn - ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - if conn == nil { - tx.Rollback() - return nil, ErrNotPgx - } - - fakeTxMutex.Lock() - fakeTxConns[conn] = tx - fakeTxMutex.Unlock() - - return conn, nil -} - -// ReleaseConn releases a *pgx.Conn acquired with AcquireConn. -func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { - var tx *sql.Tx - var ok bool - - if conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - conn.Close(ctx) - } - - fakeTxMutex.Lock() - tx, ok = fakeTxConns[conn] - if ok { - delete(fakeTxConns, conn) - fakeTxMutex.Unlock() - } else { - fakeTxMutex.Unlock() - return fmt.Errorf("can't release conn that is not acquired") - } - - return tx.Rollback() -} diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 099320c0..ca2dccf3 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -9,13 +9,15 @@ import ( "os" "reflect" "regexp" + "strconv" "testing" "time" - "github.com/Masterminds/semver/v3" - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/stdlib" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/stdlib" + "github.com/jackc/pgx/v5/tracelog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -36,7 +38,7 @@ func skipCockroachDB(t testing.TB, db *sql.DB, msg string) { require.NoError(t, err) defer conn.Close() - err = conn.Raw(func(driverConn interface{}) error { + err = conn.Raw(func(driverConn any) error { conn := driverConn.(*stdlib.Conn).Conn() if conn.PgConn().ParameterStatus("crdb_version") != "" { t.Skip(msg) @@ -46,73 +48,60 @@ func skipCockroachDB(t testing.TB, db *sql.DB, msg string) { require.NoError(t, err) } -func skipPostgreSQLVersion(t testing.TB, db *sql.DB, constraintStr, msg string) { +func skipPostgreSQLVersionLessThan(t testing.TB, db *sql.DB, minVersion int64) { conn, err := db.Conn(context.Background()) require.NoError(t, err) defer conn.Close() - err = conn.Raw(func(driverConn interface{}) error { + err = conn.Raw(func(driverConn any) error { conn := driverConn.(*stdlib.Conn).Conn() serverVersionStr := conn.PgConn().ParameterStatus("server_version") - serverVersionStr = regexp.MustCompile(`^[0-9.]+`).FindString(serverVersionStr) + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) // if not PostgreSQL do nothing if serverVersionStr == "" { return nil } - serverVersion, err := semver.NewVersion(serverVersionStr) + serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) if err != nil { return err } - c, err := semver.NewConstraint(constraintStr) - if err != nil { - return err + if serverVersion < minVersion { + t.Skipf("Test requires PostgreSQL v%d+", minVersion) } - if c.Check(serverVersion) { - t.Skip(msg) - } return nil }) require.NoError(t, err) } -func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, db *sql.DB)) { - t.Run("SimpleProto", - func(t *testing.T) { - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.PreferSimpleProtocol = true - db := stdlib.OpenDB(*config) - defer func() { - err := db.Close() +func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) { + for _, mode := range []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + pgx.QueryExecModeSimpleProtocol, + } { + t.Run(mode.String(), + func(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) - }() - f(t, db) + config.DefaultQueryExecMode = mode + db := stdlib.OpenDB(*config) + defer func() { + err := db.Close() + require.NoError(t, err) + }() - ensureDBValid(t, db) - }, - ) + f(t, db) - t.Run("DefaultProto", - func(t *testing.T) { - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - db := stdlib.OpenDB(*config) - defer func() { - err := db.Close() - require.NoError(t, err) - }() - - f(t, db) - - ensureDBValid(t, db) - }, - ) + ensureDBValid(t, db) + }, + ) + } } // Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should @@ -271,7 +260,7 @@ func TestQueryCloseRowsEarly(t *testing.T) { } func TestConnExec(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec("create temporary table t(a varchar not null)") require.NoError(t, err) @@ -285,7 +274,7 @@ func TestConnExec(t *testing.T) { } func TestConnQuery(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10)) @@ -317,7 +306,7 @@ func TestConnQuery(t *testing.T) { // https://github.com/jackc/pgx/issues/781 func TestConnQueryDifferentScanPlansIssue781(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { var s string var b bool @@ -332,7 +321,7 @@ func TestConnQueryDifferentScanPlansIssue781(t *testing.T) { } func TestConnQueryNull(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { rows, err := db.Query("select $1::int", nil) require.NoError(t, err) @@ -357,7 +346,7 @@ func TestConnQueryNull(t *testing.T) { } func TestConnQueryRowByteSlice(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { expected := []byte{222, 173, 190, 239} var actual []byte @@ -368,7 +357,7 @@ func TestConnQueryRowByteSlice(t *testing.T) { } func TestConnQueryFailure(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.Query("select 'foo") require.Error(t, err) require.IsType(t, new(pgconn.PgError), err) @@ -376,7 +365,7 @@ func TestConnQueryFailure(t *testing.T) { } func TestConnSimpleSlicePassThrough(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server does not support cardinality function") var n int64 @@ -386,10 +375,54 @@ func TestConnSimpleSlicePassThrough(t *testing.T) { }) } +func TestConnQueryScanGoArray(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + m := pgtype.NewMap() + + var a []int64 + err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) + require.NoError(t, err) + assert.Equal(t, []int64{1, 2, 3}, a) + }) +} + +func TestConnQueryScanArray(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + m := pgtype.NewMap() + + var a pgtype.Array[int64] + err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) + require.NoError(t, err) + assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a) + }) +} + +func TestConnQueryScanRange(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server does not support int4range") + + m := pgtype.NewMap() + + var r pgtype.Range[pgtype.Int4] + err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r)) + require.NoError(t, err) + assert.Equal( + t, + pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + r) + }) +} + // Test type that pgx would handle natively in binary, but since it is not a // database/sql native type should be passed through as a string func TestConnQueryRowPgxBinary(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { sql := "select $1::int4[]" expected := "{1,2,3}" var actual string @@ -401,7 +434,7 @@ func TestConnQueryRowPgxBinary(t *testing.T) { } func TestConnQueryRowUnknownType(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server does not support point type") sql := "select $1::point" @@ -415,7 +448,7 @@ func TestConnQueryRowUnknownType(t *testing.T) { } func TestConnQueryJSONIntoByteSlice(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec(` create temporary table docs( body json not null @@ -475,7 +508,7 @@ func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { } func TestTransactionLifeCycle(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec("create temporary table t(a varchar not null)") require.NoError(t, err) @@ -509,7 +542,7 @@ func TestTransactionLifeCycle(t *testing.T) { } func TestConnBeginTxIsolation(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server always uses serializable isolation level") var defaultIsoLevel string @@ -565,7 +598,7 @@ func TestConnBeginTxIsolation(t *testing.T) { } func TestConnBeginTxReadOnly(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) require.NoError(t, err) defer tx.Rollback() @@ -583,7 +616,7 @@ func TestConnBeginTxReadOnly(t *testing.T) { } func TestBeginTxContextCancel(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec("drop table if exists t") require.NoError(t, err) @@ -610,49 +643,13 @@ func TestBeginTxContextCancel(t *testing.T) { }) } -func TestAcquireConn(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { - var conns []*pgx.Conn - - for i := 1; i < 6; i++ { - conn, err := stdlib.AcquireConn(db) - if err != nil { - t.Errorf("%d. AcquireConn failed: %v", i, err) - continue - } - - var n int32 - err = conn.QueryRow(context.Background(), "select 1").Scan(&n) - if err != nil { - t.Errorf("%d. QueryRow failed: %v", i, err) - } - if n != 1 { - t.Errorf("%d. n => %d, want %d", i, n, 1) - } - - stats := db.Stats() - if stats.OpenConnections != i { - t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i) - } - - conns = append(conns, conn) - } - - for i, conn := range conns { - if err := stdlib.ReleaseConn(db, conn); err != nil { - t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err) - } - } - }) -} - func TestConnRaw(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { conn, err := db.Conn(context.Background()) require.NoError(t, err) var n int - err = conn.Raw(func(driverConn interface{}) error { + err = conn.Raw(func(driverConn any) error { conn := driverConn.(*stdlib.Conn).Conn() return conn.QueryRow(context.Background(), "select 42").Scan(&n) }) @@ -661,47 +658,15 @@ func TestConnRaw(t *testing.T) { }) } -// https://github.com/jackc/pgx/issues/673 -func TestReleaseConnWithTxInProgress(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { - skipCockroachDB(t, db, "Server does not support backend PID") - - c1, err := stdlib.AcquireConn(db) - require.NoError(t, err) - - _, err = c1.Exec(context.Background(), "begin") - require.NoError(t, err) - - c1PID := c1.PgConn().PID() - - err = stdlib.ReleaseConn(db, c1) - require.NoError(t, err) - - c2, err := stdlib.AcquireConn(db) - require.NoError(t, err) - - c2PID := c2.PgConn().PID() - - err = stdlib.ReleaseConn(db, c2) - require.NoError(t, err) - - require.NotEqual(t, c1PID, c2PID) - - // Releasing a conn with a tx in progress should close the connection - stats := db.Stats() - require.Equal(t, 1, stats.OpenConnections) - }) -} - func TestConnPingContextSuccess(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { err := db.PingContext(context.Background()) require.NoError(t, err) }) } func TestConnPrepareContextSuccess(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { stmt, err := db.PrepareContext(context.Background(), "select now()") require.NoError(t, err) err = stmt.Close() @@ -710,31 +675,14 @@ func TestConnPrepareContextSuccess(t *testing.T) { } func TestConnExecContextSuccess(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") require.NoError(t, err) }) } -func TestConnExecContextFailureRetry(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { - // We get a connection, immediately close it, and then get it back; - // DB.Conn along with Conn.ResetSession does the retry for us. - { - conn, err := stdlib.AcquireConn(db) - require.NoError(t, err) - conn.Close(context.Background()) - stdlib.ReleaseConn(db, conn) - } - conn, err := db.Conn(context.Background()) - require.NoError(t, err) - _, err = conn.ExecContext(context.Background(), "select 1") - require.NoError(t, err) - }) -} - func TestConnQueryContextSuccess(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") require.NoError(t, err) @@ -747,26 +695,8 @@ func TestConnQueryContextSuccess(t *testing.T) { }) } -func TestConnQueryContextFailureRetry(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { - // We get a connection, immediately close it, and then get it back; - // DB.Conn along with Conn.ResetSession does the retry for us. - { - conn, err := stdlib.AcquireConn(db) - require.NoError(t, err) - conn.Close(context.Background()) - stdlib.ReleaseConn(db, conn) - } - conn, err := db.Conn(context.Background()) - require.NoError(t, err) - - _, err = conn.QueryContext(context.Background(), "select 1") - require.NoError(t, err) - }) -} - func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { rows, err := db.Query("select 42::bigint") require.NoError(t, err) @@ -850,7 +780,7 @@ func TestStmtQueryContextSuccess(t *testing.T) { } func TestRowsColumnTypes(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { columnTypesTests := []struct { Name string TypeName string @@ -988,7 +918,7 @@ func TestRowsColumnTypes(t *testing.T) { } func TestQueryLifeCycle(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) @@ -1037,7 +967,7 @@ func TestQueryLifeCycle(t *testing.T) { // https://github.com/jackc/pgx/issues/409 func TestScanJSONIntoJSONRawMessage(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { var msg json.RawMessage err := db.QueryRow("select '{}'::json").Scan(&msg) @@ -1047,16 +977,16 @@ func TestScanJSONIntoJSONRawMessage(t *testing.T) { } type testLog struct { - lvl pgx.LogLevel + lvl tracelog.LogLevel msg string - data map[string]interface{} + data map[string]any } type testLogger struct { logs []testLog } -func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]interface{}) { +func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) { l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) } @@ -1065,7 +995,7 @@ func TestRegisterConnConfig(t *testing.T) { require.NoError(t, err) logger := &testLogger{} - connConfig.Logger = logger + connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo} // Issue 947: Register and unregister a ConnConfig and ensure that the // returned connection string is not reused. @@ -1092,8 +1022,8 @@ func TestRegisterConnConfig(t *testing.T) { // https://github.com/jackc/pgx/issues/958 func TestConnQueryRowConstraintErrors(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { - skipPostgreSQLVersion(t, db, "< 11", "Test requires PG 11+") + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipPostgreSQLVersionLessThan(t, db, 11) skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") _, err := db.Exec(`create temporary table defer_test ( @@ -1225,3 +1155,67 @@ func TestResetSessionHookCalled(t *testing.T) { require.True(t, mockCalled) } + +func TestCheckIdleConn(t *testing.T) { + controllerConn, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeDB(t, controllerConn) + + skipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + + db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeDB(t, db) + + var conns []*sql.Conn + for i := 0; i < 3; i++ { + c, err := db.Conn(context.Background()) + require.NoError(t, err) + conns = append(conns, c) + } + + require.EqualValues(t, 3, db.Stats().OpenConnections) + + var pids []uint32 + for _, c := range conns { + err := c.Raw(func(driverConn any) error { + pids = append(pids, driverConn.(*stdlib.Conn).Conn().PgConn().PID()) + return nil + }) + require.NoError(t, err) + err = c.Close() + require.NoError(t, err) + } + + // The database/sql connection pool seems to automatically close idle connections to only keep 2 alive. + // require.EqualValues(t, 3, db.Stats().OpenConnections) + + _, err = controllerConn.ExecContext(context.Background(), `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids) + require.NoError(t, err) + + // All conns are dead they don't know it and neither does the pool. But because of database/sql automatically closing + // idle connections we can't be sure how many we should have. require.EqualValues(t, 3, db.Stats().OpenConnections) + + // Wait long enough so the pool will realize it needs to check the connections. + time.Sleep(time.Second) + + // Pool should try all existing connections and find them dead, then create a new connection which should successfully ping. + err = db.PingContext(context.Background()) + require.NoError(t, err) + + // The original 3 conns should have been terminated and the a new conn established for the ping. + require.EqualValues(t, 1, db.Stats().OpenConnections) + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + var cPID uint32 + err = c.Raw(func(driverConn any) error { + cPID = driverConn.(*stdlib.Conn).Conn().PgConn().PID() + return nil + }) + require.NoError(t, err) + err = c.Close() + require.NoError(t, err) + + require.NotContains(t, pids, cPID) +} diff --git a/tracelog/tracelog.go b/tracelog/tracelog.go new file mode 100644 index 00000000..d51b9b95 --- /dev/null +++ b/tracelog/tracelog.go @@ -0,0 +1,295 @@ +// Package tracelog provides a tracer that acts as a traditional logger. +package tracelog + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "time" + + "github.com/jackc/pgx/v5" +) + +// LogLevel represents the pgx logging level. See LogLevel* constants for +// possible values. +type LogLevel int + +// The values for log levels are chosen such that the zero value means that no +// log level was specified. +const ( + LogLevelTrace = LogLevel(6) + LogLevelDebug = LogLevel(5) + LogLevelInfo = LogLevel(4) + LogLevelWarn = LogLevel(3) + LogLevelError = LogLevel(2) + LogLevelNone = LogLevel(1) +) + +func (ll LogLevel) String() string { + switch ll { + case LogLevelTrace: + return "trace" + case LogLevelDebug: + return "debug" + case LogLevelInfo: + return "info" + case LogLevelWarn: + return "warn" + case LogLevelError: + return "error" + case LogLevelNone: + return "none" + default: + return fmt.Sprintf("invalid level %d", ll) + } +} + +// Logger is the interface used to get log output from pgx. +type Logger interface { + // Log a message at the given level with data key/value pairs. data may be nil. + Log(ctx context.Context, level LogLevel, msg string, data map[string]any) +} + +// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface +type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) + +// Log delegates the logging request to the wrapped function +func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { + f(ctx, level, msg, data) +} + +// LogLevelFromString converts log level string to constant +// +// Valid levels: +// trace +// debug +// info +// warn +// error +// none +func LogLevelFromString(s string) (LogLevel, error) { + switch s { + case "trace": + return LogLevelTrace, nil + case "debug": + return LogLevelDebug, nil + case "info": + return LogLevelInfo, nil + case "warn": + return LogLevelWarn, nil + case "error": + return LogLevelError, nil + case "none": + return LogLevelNone, nil + default: + return 0, errors.New("invalid log level") + } +} + +func logQueryArgs(args []any) []any { + logArgs := make([]any, 0, len(args)) + + for _, a := range args { + switch v := a.(type) { + case []byte: + if len(v) < 64 { + a = hex.EncodeToString(v) + } else { + a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64) + } + case string: + if len(v) > 64 { + a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64) + } + } + logArgs = append(logArgs, a) + } + + return logArgs +} + +// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, and pgx.CopyFromTracer. All fields are +// required. +type TraceLog struct { + Logger Logger + LogLevel LogLevel +} + +type ctxKey int + +const ( + _ ctxKey = iota + tracelogQueryCtxKey + tracelogBatchCtxKey + tracelogCopyFromCtxKey + tracelogConnectCtxKey +) + +type traceQueryData struct { + startTime time.Time + sql string + args []any +} + +func (tl *TraceLog) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + return context.WithValue(ctx, tracelogQueryCtxKey, &traceQueryData{ + startTime: time.Now(), + sql: data.SQL, + args: data.Args, + }) +} + +func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + queryData := ctx.Value(tracelogQueryCtxKey).(*traceQueryData) + + endTime := time.Now() + interval := endTime.Sub(queryData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, "time": interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "time": interval, "commandTag": data.CommandTag.String()}) + } +} + +type traceBatchData struct { + startTime time.Time +} + +func (tl *TraceLog) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + return context.WithValue(ctx, tracelogBatchCtxKey, &traceBatchData{ + startTime: time.Now(), + }) +} + +func (tl *TraceLog) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "err": data.Err}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "commandTag": data.CommandTag.String()}) + } +} + +func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + queryData := ctx.Value(tracelogBatchCtxKey).(*traceBatchData) + + endTime := time.Now() + interval := endTime.Sub(queryData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, "time": interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{"time": interval}) + } +} + +type traceCopyFromData struct { + startTime time.Time + TableName pgx.Identifier + ColumnNames []string +} + +func (tl *TraceLog) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + return context.WithValue(ctx, tracelogCopyFromCtxKey, &traceCopyFromData{ + startTime: time.Now(), + TableName: data.TableName, + ColumnNames: data.ColumnNames, + }) +} + +func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + copyFromData := ctx.Value(tracelogCopyFromCtxKey).(*traceCopyFromData) + + endTime := time.Now() + interval := endTime.Sub(copyFromData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval, "rowCount": data.CommandTag.RowsAffected()}) + } +} + +type traceConnectData struct { + startTime time.Time + connConfig *pgx.ConnConfig +} + +func (tl *TraceLog) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + return context.WithValue(ctx, tracelogConnectCtxKey, &traceConnectData{ + startTime: time.Now(), + connConfig: data.ConnConfig, + }) +} + +func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + connectData := ctx.Value(tracelogConnectCtxKey).(*traceConnectData) + + endTime := time.Now() + interval := endTime.Sub(connectData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.Logger.Log(ctx, LogLevelError, "Connect", map[string]any{ + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + "time": interval, + "err": data.Err, + }) + } + return + } + + if data.Conn != nil { + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, data.Conn, LogLevelInfo, "Connect", map[string]any{ + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + "time": interval, + }) + } + } +} + +func (tl *TraceLog) shouldLog(lvl LogLevel) bool { + return tl.LogLevel >= lvl +} + +func (tl *TraceLog) log(ctx context.Context, conn *pgx.Conn, lvl LogLevel, msg string, data map[string]any) { + if data == nil { + data = map[string]any{} + } + + pgConn := conn.PgConn() + if pgConn != nil { + pid := pgConn.PID() + if pid != 0 { + data["pid"] = pid + } + } + + tl.Logger.Log(ctx, lvl, msg, data) +} diff --git a/tracelog/tracelog_test.go b/tracelog/tracelog_test.go new file mode 100644 index 00000000..ed0f8eab --- /dev/null +++ b/tracelog/tracelog_test.go @@ -0,0 +1,301 @@ +package tracelog_test + +import ( + "bytes" + "context" + "log" + "os" + "strings" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/jackc/pgx/v5/tracelog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var defaultConnTestRunner pgxtest.ConnTestRunner + +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config + } +} + +type testLog struct { + lvl tracelog.LogLevel + msg string + data map[string]any +} + +type testLogger struct { + logs []testLog +} + +func (l *testLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { + data["ctxdata"] = ctx.Value("ctxdata") + l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) +} + +func TestContextGetsPassedToLogMethod(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + ctx = context.WithValue(context.Background(), "ctxdata", "foo") + _, err := conn.Exec(ctx, `;`) + require.NoError(t, err) + require.Len(t, logger.logs, 1) + require.Equal(t, "foo", logger.logs[0].data["ctxdata"]) + }) +} + +func TestLoggerFunc(t *testing.T) { + t.Parallel() + + const testMsg = "foo" + + buf := bytes.Buffer{} + logger := log.New(&buf, "", 0) + + createAdapterFn := func(logger *log.Logger) tracelog.LoggerFunc { + return func(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]interface{}) { + logger.Printf("%s", testMsg) + } + } + + config := defaultConnTestRunner.CreateConfig(context.Background(), t) + config.Tracer = &tracelog.TraceLog{ + Logger: createAdapterFn(logger), + LogLevel: tracelog.LogLevelTrace, + } + + conn, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer conn.Close(context.Background()) + + buf.Reset() // Clear logs written when establishing connection + + if _, err := conn.Exec(context.TODO(), ";"); err != nil { + t.Fatal(err) + } + + if strings.TrimSpace(buf.String()) != testMsg { + t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String()) + } +} + +func TestLogQuery(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + _, err := conn.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + require.Len(t, logger.logs, 1) + require.Equal(t, "Query", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + _, err = conn.Exec(ctx, `foo`, "testing") + require.Error(t, err) + require.Len(t, logger.logs, 2) + require.Equal(t, "Query", logger.logs[1].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[1].lvl) + require.Equal(t, err, logger.logs[1].data["err"]) + }) +} + +func TestLogCopyFrom(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`) + require.NoError(t, err) + + logger.logs = logger.logs[0:0] + + inputRows := [][]any{ + {int32(1)}, + {nil}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + require.Len(t, logger.logs, 1) + require.Equal(t, "CopyFrom", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + logger.logs = logger.logs[0:0] + + inputRows = [][]any{ + {"not an integer"}, + {nil}, + } + + copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.Error(t, err) + require.EqualValues(t, 0, copyCount) + require.Len(t, logger.logs, 1) + require.Equal(t, "CopyFrom", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl) + }) +} + +func TestLogConnect(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + config := defaultConnTestRunner.CreateConfig(context.Background(), t) + config.Tracer = tracer + + conn1, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer conn1.Close(context.Background()) + require.Len(t, logger.logs, 1) + require.Equal(t, "Connect", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + logger.logs = logger.logs[0:0] + + config, err = pgx.ParseConfig("host=/invalid") + require.NoError(t, err) + config.Tracer = tracer + + conn2, err := pgx.ConnectConfig(context.Background(), config) + require.Nil(t, conn2) + require.Error(t, err) + require.Len(t, logger.logs, 1) + require.Equal(t, "Connect", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl) +} + +func TestLogBatchStatementsOnExec(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + batch := &pgx.Batch{} + batch.Queue("create table foo (id bigint)") + batch.Queue("drop table foo") + + br := conn.SendBatch(context.Background(), batch) + + _, err := br.Exec() + require.NoError(t, err) + + _, err = br.Exec() + require.NoError(t, err) + + err = br.Close() + require.NoError(t, err) + + require.Len(t, logger.logs, 3) + assert.Equal(t, "BatchQuery", logger.logs[0].msg) + assert.Equal(t, "create table foo (id bigint)", logger.logs[0].data["sql"]) + assert.Equal(t, "BatchQuery", logger.logs[1].msg) + assert.Equal(t, "drop table foo", logger.logs[1].data["sql"]) + assert.Equal(t, "BatchClose", logger.logs[2].msg) + + }) +} + +func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + batch := &pgx.Batch{} + batch.Queue("select generate_series(1,$1)", 100) + batch.Queue("select 1 = 1;") + + br := conn.SendBatch(context.Background(), batch) + err := br.Close() + require.NoError(t, err) + + require.Len(t, logger.logs, 3) + assert.Equal(t, "BatchQuery", logger.logs[0].msg) + assert.Equal(t, "select generate_series(1,$1)", logger.logs[0].data["sql"]) + assert.Equal(t, "BatchQuery", logger.logs[1].msg) + assert.Equal(t, "select 1 = 1;", logger.logs[1].data["sql"]) + assert.Equal(t, "BatchClose", logger.logs[2].msg) + }) +} diff --git a/tracer.go b/tracer.go new file mode 100644 index 00000000..58ca99f7 --- /dev/null +++ b/tracer.go @@ -0,0 +1,107 @@ +package pgx + +import ( + "context" + + "github.com/jackc/pgx/v5/pgconn" +) + +// QueryTracer traces Query, QueryRow, and Exec. +type QueryTracer interface { + // TraceQueryStart is called at the beginning of Query, QueryRow, and Exec calls. The returned context is used for the + // rest of the call and will be passed to TraceQueryEnd. + TraceQueryStart(ctx context.Context, conn *Conn, data TraceQueryStartData) context.Context + + TraceQueryEnd(ctx context.Context, conn *Conn, data TraceQueryEndData) +} + +type TraceQueryStartData struct { + SQL string + Args []any +} + +type TraceQueryEndData struct { + CommandTag pgconn.CommandTag + Err error +} + +// BatchTracer traces SendBatch. +type BatchTracer interface { + // TraceBatchStart is called at the beginning of SendBatch calls. The returned context is used for the + // rest of the call and will be passed to TraceBatchQuery and TraceBatchEnd. + TraceBatchStart(ctx context.Context, conn *Conn, data TraceBatchStartData) context.Context + + TraceBatchQuery(ctx context.Context, conn *Conn, data TraceBatchQueryData) + TraceBatchEnd(ctx context.Context, conn *Conn, data TraceBatchEndData) +} + +type TraceBatchStartData struct { + Batch *Batch +} + +type TraceBatchQueryData struct { + SQL string + Args []any + CommandTag pgconn.CommandTag + Err error +} + +type TraceBatchEndData struct { + Err error +} + +// CopyFromTracer traces CopyFrom. +type CopyFromTracer interface { + // TraceCopyFromStart is called at the beginning of CopyFrom calls. The returned context is used for the + // rest of the call and will be passed to TraceCopyFromEnd. + TraceCopyFromStart(ctx context.Context, conn *Conn, data TraceCopyFromStartData) context.Context + + TraceCopyFromEnd(ctx context.Context, conn *Conn, data TraceCopyFromEndData) +} + +type TraceCopyFromStartData struct { + TableName Identifier + ColumnNames []string +} + +type TraceCopyFromEndData struct { + CommandTag pgconn.CommandTag + Err error +} + +// PrepareTracer traces Prepare. +type PrepareTracer interface { + // TracePrepareStart is called at the beginning of Prepare calls. The returned context is used for the + // rest of the call and will be passed to TracePrepareEnd. + TracePrepareStart(ctx context.Context, conn *Conn, data TracePrepareStartData) context.Context + + TracePrepareEnd(ctx context.Context, conn *Conn, data TracePrepareEndData) +} + +type TracePrepareStartData struct { + Name string + SQL string +} + +type TracePrepareEndData struct { + AlreadyPrepared bool + Err error +} + +// ConnectTracer traces Connect and ConnectConfig. +type ConnectTracer interface { + // TraceConnectStart is called at the beginning of Connect and ConnectConfig calls. The returned context is used for + // the rest of the call and will be passed to TraceConnectEnd. + TraceConnectStart(ctx context.Context, data TraceConnectStartData) context.Context + + TraceConnectEnd(ctx context.Context, data TraceConnectEndData) +} + +type TraceConnectStartData struct { + ConnConfig *ConnConfig +} + +type TraceConnectEndData struct { + Conn *Conn + Err error +} diff --git a/tracer_test.go b/tracer_test.go new file mode 100644 index 00000000..86375b34 --- /dev/null +++ b/tracer_test.go @@ -0,0 +1,538 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +type testTracer struct { + traceQueryStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context + traceQueryEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) + traceBatchStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context + traceBatchQuery func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) + traceBatchEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) + traceCopyFromStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context + traceCopyFromEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) + tracePrepareStart func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context + tracePrepareEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) + traceConnectStart func(ctx context.Context, data pgx.TraceConnectStartData) context.Context + traceConnectEnd func(ctx context.Context, data pgx.TraceConnectEndData) +} + +func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + if tt.traceQueryStart != nil { + return tt.traceQueryStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + if tt.traceQueryEnd != nil { + tt.traceQueryEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + if tt.traceBatchStart != nil { + return tt.traceBatchStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + if tt.traceBatchQuery != nil { + tt.traceBatchQuery(ctx, conn, data) + } +} + +func (tt *testTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + if tt.traceBatchEnd != nil { + tt.traceBatchEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + if tt.traceCopyFromStart != nil { + return tt.traceCopyFromStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + if tt.traceCopyFromEnd != nil { + tt.traceCopyFromEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + if tt.tracePrepareStart != nil { + return tt.tracePrepareStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + if tt.tracePrepareEnd != nil { + tt.tracePrepareEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + if tt.traceConnectStart != nil { + return tt.traceConnectStart(ctx, data) + } + return ctx +} + +func (tt *testTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + if tt.traceConnectEnd != nil { + tt.traceConnectEnd(ctx, data) + } +} + +func TestTraceExec(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceQueryStartCalled := false + tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + traceQueryStartCalled = true + require.Equal(t, `select $1::text`, data.SQL) + require.Len(t, data.Args, 1) + require.Equal(t, `testing`, data.Args[0]) + return context.WithValue(ctx, "fromTraceQueryStart", "foo") + } + + traceQueryEndCalled := false + tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + traceQueryEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceQueryStart")) + require.Equal(t, `SELECT 1`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + _, err := conn.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + require.True(t, traceQueryStartCalled) + require.True(t, traceQueryEndCalled) + }) +} + +func TestTraceQuery(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceQueryStartCalled := false + tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + traceQueryStartCalled = true + require.Equal(t, `select $1::text`, data.SQL) + require.Len(t, data.Args, 1) + require.Equal(t, `testing`, data.Args[0]) + return context.WithValue(ctx, "fromTraceQueryStart", "foo") + } + + traceQueryEndCalled := false + tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + traceQueryEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceQueryStart")) + require.Equal(t, `SELECT 1`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + var s string + err := conn.QueryRow(ctx, `select $1::text`, "testing").Scan(&s) + require.NoError(t, err) + require.Equal(t, "testing", s) + require.True(t, traceQueryStartCalled) + require.True(t, traceQueryEndCalled) + }) +} + +func TestTraceBatchNormal(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 2, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + + var n int32 + err := br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + require.EqualValues(t, 1, traceBatchQueryCalledCount) + + err = br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 2, n) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + + err = br.Close() + require.NoError(t, err) + + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchClose(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 2, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + err := br.Close() + require.NoError(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchErrorWhileReadingResults(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 3, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + if traceBatchQueryCalledCount == 2 { + require.Error(t, data.Err) + } else { + require.NoError(t, data.Err) + } + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.Error(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2/n-2 from generate_series(0,10) n`) + batch.Queue(`select 3`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + + commandTag, err := br.Exec() + require.NoError(t, err) + require.Equal(t, "SELECT 1", commandTag.String()) + + commandTag, err = br.Exec() + require.Error(t, err) + require.Equal(t, "", commandTag.String()) + + commandTag, err = br.Exec() + require.Error(t, err) + require.Equal(t, "", commandTag.String()) + + err = br.Close() + require.Error(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 3, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + if traceBatchQueryCalledCount == 2 { + require.Error(t, data.Err) + } else { + require.NoError(t, data.Err) + } + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.Error(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2/n-2 from generate_series(0,10) n`) + batch.Queue(`select 3`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + err := br.Close() + require.Error(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceCopyFrom(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceCopyFromStartCalled := false + tracer.traceCopyFromStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + traceCopyFromStartCalled = true + require.Equal(t, pgx.Identifier{"foo"}, data.TableName) + require.Equal(t, []string{"a"}, data.ColumnNames) + return context.WithValue(ctx, "fromTraceCopyFromStart", "foo") + } + + traceCopyFromEndCalled := false + tracer.traceCopyFromEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + traceCopyFromEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceCopyFromStart")) + require.Equal(t, `COPY 2`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + _, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`) + require.NoError(t, err) + + inputRows := [][]any{ + {int32(1)}, + {nil}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + require.True(t, traceCopyFromStartCalled) + require.True(t, traceCopyFromEndCalled) + }) +} + +func TestTracePrepare(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tracePrepareStartCalled := false + tracer.tracePrepareStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + tracePrepareStartCalled = true + require.Equal(t, `ps`, data.Name) + require.Equal(t, `select $1::text`, data.SQL) + return context.WithValue(ctx, "fromTracePrepareStart", "foo") + } + + tracePrepareEndCalled := false + tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + tracePrepareEndCalled = true + require.False(t, data.AlreadyPrepared) + require.NoError(t, data.Err) + } + + _, err := conn.Prepare(ctx, "ps", `select $1::text`) + require.NoError(t, err) + require.True(t, tracePrepareStartCalled) + require.True(t, tracePrepareEndCalled) + + tracePrepareStartCalled = false + tracePrepareEndCalled = false + tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + tracePrepareEndCalled = true + require.True(t, data.AlreadyPrepared) + require.NoError(t, data.Err) + } + + _, err = conn.Prepare(ctx, "ps", `select $1::text`) + require.NoError(t, err) + require.True(t, tracePrepareStartCalled) + require.True(t, tracePrepareEndCalled) + }) +} + +func TestTraceConnect(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + config := defaultConnTestRunner.CreateConfig(context.Background(), t) + config.Tracer = tracer + + traceConnectStartCalled := false + tracer.traceConnectStart = func(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + traceConnectStartCalled = true + require.NotNil(t, data.ConnConfig) + return context.WithValue(ctx, "fromTraceConnectStart", "foo") + } + + traceConnectEndCalled := false + tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) { + traceConnectEndCalled = true + require.NotNil(t, data.Conn) + require.NoError(t, data.Err) + } + + conn1, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer conn1.Close(context.Background()) + require.True(t, traceConnectStartCalled) + require.True(t, traceConnectEndCalled) + + config, err = pgx.ParseConfig("host=/invalid") + require.NoError(t, err) + config.Tracer = tracer + + traceConnectStartCalled = false + traceConnectEndCalled = false + tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) { + traceConnectEndCalled = true + require.Nil(t, data.Conn) + require.Error(t, data.Err) + } + + conn2, err := pgx.ConnectConfig(context.Background(), config) + require.Nil(t, conn2) + require.Error(t, err) + require.True(t, traceConnectStartCalled) + require.True(t, traceConnectEndCalled) +} diff --git a/tx.go b/tx.go index 2914ada7..24daf0f8 100644 --- a/tx.go +++ b/tx.go @@ -7,7 +7,7 @@ import ( "fmt" "strconv" - "github.com/jackc/pgconn" + "github.com/jackc/pgx/v5/pgconn" ) // TxIsoLevel is the transaction isolation level (serializable, repeatable read, read committed or read uncommitted) @@ -94,39 +94,6 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) { return &dbTx{conn: c}, nil } -// BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns -// an error the transaction is rolled back. The context will be used when executing the transaction control statements -// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f. -func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { - return c.BeginTxFunc(ctx, TxOptions{}, f) -} - -// BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return -// an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be -// used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect -// the execution of f. -func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) { - var tx Tx - tx, err = c.BeginTx(ctx, txOptions) - if err != nil { - return err - } - defer func() { - rollbackErr := tx.Rollback(ctx) - if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { - err = rollbackErr - } - }() - - fErr := f(tx) - if fErr != nil { - _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return - return fErr - } - - return tx.Commit(ctx) -} - // Tx represents a database transaction. // // Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx @@ -138,20 +105,17 @@ type Tx interface { // Begin starts a pseudo nested transaction. Begin(ctx context.Context) (Tx, error) - // BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested - // transaction will be committed. If it does then it will be rolled back. - BeginFunc(ctx context.Context, f func(Tx) error) (err error) - // Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested - // transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple - // times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then - // ErrTxCommitRollback will be returned. + // transaction. Commit will return an error where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is + // otherwise safe to call multiple times. If the commit fails with a rollback status (e.g. the transaction was already + // in a broken state) then an error where errors.Is(ErrTxCommitRollback) is true will be returned. Commit(ctx context.Context) error // Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a - // pseudo nested transaction. Rollback will return ErrTxClosed if the Tx is already closed, but is otherwise safe to - // call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will be called first in a non-error - // condition. Any other failure of a real transaction will result in the connection being closed. + // pseudo nested transaction. Rollback will return an error where errors.Is(ErrTxClosed) is true if the Tx is already + // closed, but is otherwise safe to call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will + // be called first in a non-error condition. Any other failure of a real transaction will result in the connection + // being closed. Rollback(ctx context.Context) error CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) @@ -160,10 +124,9 @@ type Tx interface { Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) - Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) - Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) - QueryRow(ctx context.Context, sql string, args ...interface{}) Row - QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) + Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) + Query(ctx context.Context, sql string, args ...any) (Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) Row // Conn returns the underlying *Conn that on which this transaction is executing. Conn() *Conn @@ -195,32 +158,6 @@ func (tx *dbTx) Begin(ctx context.Context) (Tx, error) { return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil } -func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { - if tx.closed { - return ErrTxClosed - } - - var savepoint Tx - savepoint, err = tx.Begin(ctx) - if err != nil { - return err - } - defer func() { - rollbackErr := savepoint.Rollback(ctx) - if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { - err = rollbackErr - } - }() - - fErr := f(savepoint) - if fErr != nil { - _ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return - return fErr - } - - return savepoint.Commit(ctx) -} - // Commit commits the transaction. func (tx *dbTx) Commit(ctx context.Context) error { if tx.closed { @@ -235,7 +172,7 @@ func (tx *dbTx) Commit(ctx context.Context) error { } return err } - if string(commandTag) == "ROLLBACK" { + if commandTag.String() == "ROLLBACK" { return ErrTxCommitRollback } @@ -263,7 +200,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error { } // Exec delegates to the underlying *Conn -func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { +func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { return tx.conn.Exec(ctx, sql, arguments...) } @@ -277,29 +214,20 @@ func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*pgconn.Statemen } // Query delegates to the underlying *Conn -func (tx *dbTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { +func (tx *dbTx) Query(ctx context.Context, sql string, args ...any) (Rows, error) { if tx.closed { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed - return &connRows{closed: true, err: err}, err + return &baseRows{closed: true, err: err}, err } return tx.conn.Query(ctx, sql, args...) } // QueryRow delegates to the underlying *Conn -func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { +func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := tx.Query(ctx, sql, args...) - return (*connRow)(rows.(*connRows)) -} - -// QueryFunc delegates to the underlying *Conn. -func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { - if tx.closed { - return nil, ErrTxClosed - } - - return tx.conn.QueryFunc(ctx, sql, args, scans, f) + return (*connRow)(rows.(*baseRows)) } // CopyFrom delegates to the underlying *Conn @@ -345,14 +273,6 @@ func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) { return sp.tx.Begin(ctx) } -func (sp *dbSimulatedNestedTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { - if sp.closed { - return ErrTxClosed - } - - return sp.tx.BeginFunc(ctx, f) -} - // Commit releases the savepoint essentially committing the pseudo nested transaction. func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error { if sp.closed { @@ -378,9 +298,9 @@ func (sp *dbSimulatedNestedTx) Rollback(ctx context.Context) error { } // Exec delegates to the underlying Tx -func (sp *dbSimulatedNestedTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { +func (sp *dbSimulatedNestedTx) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { if sp.closed { - return nil, ErrTxClosed + return pgconn.CommandTag{}, ErrTxClosed } return sp.tx.Exec(ctx, sql, arguments...) @@ -396,29 +316,20 @@ func (sp *dbSimulatedNestedTx) Prepare(ctx context.Context, name, sql string) (* } // Query delegates to the underlying Tx -func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { +func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...any) (Rows, error) { if sp.closed { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed - return &connRows{closed: true, err: err}, err + return &baseRows{closed: true, err: err}, err } return sp.tx.Query(ctx, sql, args...) } // QueryRow delegates to the underlying Tx -func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { +func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := sp.Query(ctx, sql, args...) - return (*connRow)(rows.(*connRows)) -} - -// QueryFunc delegates to the underlying Tx. -func (sp *dbSimulatedNestedTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { - if sp.closed { - return nil, ErrTxClosed - } - - return sp.tx.QueryFunc(ctx, sql, args, scans, f) + return (*connRow)(rows.(*baseRows)) } // CopyFrom delegates to the underlying *Conn @@ -446,3 +357,59 @@ func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects { func (sp *dbSimulatedNestedTx) Conn() *Conn { return sp.tx.Conn() } + +// BeginFunc calls Begin on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn +// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements +// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn. +func BeginFunc( + ctx context.Context, + db interface { + Begin(ctx context.Context) (Tx, error) + }, + fn func(Tx) error, +) (err error) { + var tx Tx + tx, err = db.Begin(ctx) + if err != nil { + return err + } + + return beginFuncExec(ctx, tx, fn) +} + +// BeginTxFunc calls BeginTx on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn +// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements +// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn. +func BeginTxFunc( + ctx context.Context, + db interface { + BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) + }, + txOptions TxOptions, + fn func(Tx) error, +) (err error) { + var tx Tx + tx, err = db.BeginTx(ctx, txOptions) + if err != nil { + return err + } + + return beginFuncExec(ctx, tx, fn) +} + +func beginFuncExec(ctx context.Context, tx Tx, fn func(Tx) error) (err error) { + defer func() { + rollbackErr := tx.Rollback(ctx) + if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { + err = rollbackErr + } + }() + + fErr := fn(tx) + if fErr != nil { + _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return + return fErr + } + + return tx.Commit(ctx) +} diff --git a/tx_test.go b/tx_test.go index e9830d32..9c1c70d3 100644 --- a/tx_test.go +++ b/tx_test.go @@ -7,8 +7,9 @@ import ( "testing" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) @@ -106,7 +107,7 @@ func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") createSql := ` create temporary table foo( @@ -273,7 +274,7 @@ func TestBeginIsoLevels(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server always uses SERIALIZABLE isolation (https://www.cockroachlabs.com/docs/stable/demo-serializable.html)") + pgxtest.SkipCockroachDB(t, conn, "Server always uses SERIALIZABLE isolation (https://www.cockroachlabs.com/docs/stable/demo-serializable.html)") isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} for _, iso := range isoLevels { @@ -311,7 +312,7 @@ func TestBeginFunc(t *testing.T) { _, err := conn.Exec(context.Background(), createSql) require.NoError(t, err) - err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) return nil @@ -340,7 +341,7 @@ func TestBeginFuncRollbackOnError(t *testing.T) { _, err := conn.Exec(context.Background(), createSql) require.NoError(t, err) - err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) return errors.New("some error") @@ -521,15 +522,15 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { _, err := db.Exec(context.Background(), createSql) require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (3)") require.NoError(t, err) return nil @@ -564,11 +565,11 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { _, err := db.Exec(context.Background(), createSql) require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") require.NoError(t, err) return errors.New("do a rollback") diff --git a/values.go b/values.go index 1a945475..19c642fa 100644 --- a/values.go +++ b/values.go @@ -1,14 +1,11 @@ package pgx import ( - "database/sql/driver" - "fmt" - "math" - "reflect" - "time" + "errors" - "github.com/jackc/pgio" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v5/internal/anynil" + "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgtype" ) // PostgreSQL format codes @@ -17,264 +14,55 @@ const ( BinaryFormatCode = 1 ) -// SerializationError occurs on failure to encode or decode a value -type SerializationError string - -func (e SerializationError) Error() string { - return string(e) -} - -func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) { - if arg == nil { +func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) { + if anynil.Is(arg) { return nil, nil } - refVal := reflect.ValueOf(arg) - if refVal.Kind() == reflect.Ptr && refVal.IsNil() { + buf, err := m.Encode(0, TextFormatCode, arg, []byte{}) + if err != nil { + return nil, err + } + if buf == nil { return nil, nil } - - switch arg := arg.(type) { - - // https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface - // []byte to database/sql instead of string. But that caused problems with the - // simple protocol because the driver.Valuer case got taken before the - // pgtype.TextEncoder case. And driver.Valuer needed to be first in the usual - // case because of https://github.com/jackc/pgx/issues/339. So instead we - // special case JSON and JSONB. - case *pgtype.JSON: - buf, err := arg.EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - case *pgtype.JSONB: - buf, err := arg.EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - - case driver.Valuer: - return callValuerValue(arg) - case pgtype.TextEncoder: - buf, err := arg.EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - case float32: - return float64(arg), nil - case float64: - return arg, nil - case bool: - return arg, nil - case time.Duration: - return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil - case time.Time: - return arg, nil - case string: - return arg, nil - case []byte: - return arg, nil - case int8: - return int64(arg), nil - case int16: - return int64(arg), nil - case int32: - return int64(arg), nil - case int64: - return arg, nil - case int: - return int64(arg), nil - case uint8: - return int64(arg), nil - case uint16: - return int64(arg), nil - case uint32: - return int64(arg), nil - case uint64: - if arg > math.MaxInt64 { - return nil, fmt.Errorf("arg too big for int64: %v", arg) - } - return int64(arg), nil - case uint: - if uint64(arg) > math.MaxInt64 { - return nil, fmt.Errorf("arg too big for int64: %v", arg) - } - return int64(arg), nil - } - - if dt, found := ci.DataTypeForValue(arg); found { - v := dt.Value - err := v.Set(arg) - if err != nil { - return nil, err - } - buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - } - - if refVal.Kind() == reflect.Ptr { - arg = refVal.Elem().Interface() - return convertSimpleArgument(ci, arg) - } - - if strippedArg, ok := stripNamedType(&refVal); ok { - return convertSimpleArgument(ci, strippedArg) - } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) + return string(buf), nil } -func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32, arg interface{}) ([]byte, error) { - if arg == nil { +func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { + if anynil.Is(arg) { return pgio.AppendInt32(buf, -1), nil } - switch arg := arg.(type) { - case pgtype.BinaryEncoder: - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := arg.EncodeBinary(ci, buf) - if err != nil { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) + if err != nil { + if argBuf2, err2 := tryScanStringCopyValueThenEncode(m, buf, oid, arg); err2 == nil { + argBuf = argBuf2 + } else { return nil, err } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil - case pgtype.TextEncoder: - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := arg.EncodeText(ci, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil - case string: - buf = pgio.AppendInt32(buf, int32(len(arg))) - buf = append(buf, arg...) - return buf, nil } - refVal := reflect.ValueOf(arg) - - if refVal.Kind() == reflect.Ptr { - if refVal.IsNil() { - return pgio.AppendInt32(buf, -1), nil - } - arg = refVal.Elem().Interface() - return encodePreparedStatementArgument(ci, buf, oid, arg) + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - - if dt, ok := ci.DataTypeForOID(oid); ok { - value := dt.Value - err := value.Set(arg) - if err != nil { - { - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) - if err != nil { - return nil, err - } - return encodePreparedStatementArgument(ci, buf, oid, v) - } - } - - return nil, err - } - - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil - } - - if strippedArg, ok := stripNamedType(&refVal); ok { - return encodePreparedStatementArgument(ci, buf, oid, strippedArg) - } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + return buf, nil } -// chooseParameterFormatCode determines the correct format code for an -// argument to a prepared statement. It defaults to TextFormatCode if no -// determination can be made. -func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 { - switch arg := arg.(type) { - case pgtype.ParamFormatPreferrer: - return arg.PreferredParamFormat() - case pgtype.BinaryEncoder: - return BinaryFormatCode - case string, *string, pgtype.TextEncoder: - return TextFormatCode +func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { + s, ok := arg.(string) + if !ok { + return nil, errors.New("not a string") } - return ci.ParamFormatCodeForOID(oid) -} - -func stripNamedType(val *reflect.Value) (interface{}, bool) { - switch val.Kind() { - case reflect.Int: - convVal := int(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int8: - convVal := int8(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int16: - convVal := int16(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int32: - convVal := int32(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int64: - convVal := int64(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint: - convVal := uint(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint8: - convVal := uint8(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint16: - convVal := uint16(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint32: - convVal := uint32(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint64: - convVal := uint64(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.String: - convVal := val.String() - return convVal, reflect.TypeOf(convVal) != val.Type() + var v any + err := m.Scan(oid, TextFormatCode, []byte(s), &v) + if err != nil { + return nil, err } - return nil, false + return m.Encode(oid, BinaryFormatCode, v, buf) } diff --git a/values_test.go b/values_test.go index 6ae6c8a0..39bf1ead 100644 --- a/values_test.go +++ b/values_test.go @@ -10,7 +10,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,7 +19,7 @@ import ( func TestDateTranscode(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { dates := []time.Time{ time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), @@ -57,7 +58,7 @@ func TestDateTranscode(t *testing.T) { func TestTimestampTzTranscode(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) var outputTime time.Time @@ -77,9 +78,9 @@ func TestTimestampTzTranscode(t *testing.T) { func TestJSONAndJSONBTranscode(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { for _, typename := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok { + if _, ok := conn.TypeMap().TypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL } @@ -96,7 +97,7 @@ func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) { defer closeConn(t, conn) for _, typename := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok { + if _, ok := conn.TypeMap().TypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL } testJSONSingleLevelStringMap(t, conn, typename) @@ -109,7 +110,7 @@ func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) { } -func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONString(t testing.TB, conn *pgx.Conn, typename string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string @@ -125,7 +126,7 @@ func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { } } -func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONStringPointer(t testing.TB, conn *pgx.Conn, typename string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string @@ -157,12 +158,12 @@ func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) } func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { - input := map[string]interface{}{ + input := map[string]any{ "name": "Uncanny", - "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, - "inventory": []interface{}{"phone", "key"}, + "stats": map[string]any{"hp": float64(107), "maxhp": float64(150)}, + "inventory": []any{"phone", "key"}, } - var output map[string]interface{} + var output map[string]any err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) @@ -170,7 +171,7 @@ func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output) + t.Errorf("%s: Did not transcode map[string]any successfully: %v is not %v", typename, input, output) return } } @@ -233,7 +234,7 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { } } -func mustParseCIDR(t *testing.T, s string) *net.IPNet { +func mustParseCIDR(t testing.TB, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) @@ -242,35 +243,10 @@ func mustParseCIDR(t *testing.T, s string) *net.IPNet { return ipnet } -func TestStringToNotTextTypeTranscode(t *testing.T) { - t.Parallel() - - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - input := "01086ee0-4963-4e35-9116-30c173a8d0bd" - - var output string - err := conn.QueryRow(context.Background(), "select $1::uuid", input).Scan(&output) - if err != nil { - t.Fatal(err) - } - if input != output { - t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output) - } - - err = conn.QueryRow(context.Background(), "select $1::uuid", &input).Scan(&output) - if err != nil { - t.Fatal(err) - } - if input != output { - t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output) - } - }) -} - func TestInetCIDRTranscodeIPNet(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string value *net.IPNet @@ -321,7 +297,7 @@ func TestInetCIDRTranscodeIPNet(t *testing.T) { func TestInetCIDRTranscodeIP(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string value net.IP @@ -385,7 +361,7 @@ func TestInetCIDRTranscodeIP(t *testing.T) { func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string value []*net.IPNet @@ -448,7 +424,7 @@ func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { func TestInetCIDRArrayTranscodeIP(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string value []net.IP @@ -534,7 +510,7 @@ func TestInetCIDRArrayTranscodeIP(t *testing.T) { func TestInetCIDRTranscodeWithJustIP(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string value string @@ -580,16 +556,16 @@ func TestInetCIDRTranscodeWithJustIP(t *testing.T) { func TestArrayDecoding(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string - query interface{} - scan interface{} - assert func(*testing.T, interface{}, interface{}) + query any + scan any + assert func(testing.TB, any, any) }{ { "select $1::bool[]", []bool{true, false, true}, &[]bool{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]bool))) { t.Errorf("failed to encode bool[]") } @@ -597,7 +573,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]int16))) { t.Errorf("failed to encode smallint[]") } @@ -605,7 +581,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { t.Errorf("failed to encode smallint[]") } @@ -613,7 +589,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]int32))) { t.Errorf("failed to encode int[]") } @@ -621,7 +597,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { t.Errorf("failed to encode int[]") } @@ -629,7 +605,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]int64))) { t.Errorf("failed to encode bigint[]") } @@ -637,7 +613,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { t.Errorf("failed to encode bigint[]") } @@ -645,7 +621,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]string))) { t.Errorf("failed to encode text[]") } @@ -653,7 +629,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan any) { queryTimeSlice := query.([]time.Time) scanTimeSlice := *(scan.(*[]time.Time)) require.Equal(t, len(queryTimeSlice), len(scanTimeSlice)) @@ -664,7 +640,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan any) { queryBytesSliceSlice := query.([][]byte) scanBytesSliceSlice := *(scan.(*[][]byte)) if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) { @@ -696,7 +672,7 @@ func TestArrayDecoding(t *testing.T) { func TestEmptyArrayDecoding(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { var val []string err := conn.QueryRow(context.Background(), "select array[]::text[]").Scan(&val) @@ -741,8 +717,8 @@ func TestEmptyArrayDecoding(t *testing.T) { func TestPointerPointer(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server auto converts ints to bigint and test relies on exact types") + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server auto converts ints to bigint and test relies on exact types") type allTypes struct { s *string @@ -778,26 +754,26 @@ func TestPointerPointer(t *testing.T) { tests := []struct { sql string - queryArgs []interface{} - scanArgs []interface{} + queryArgs []any + scanArgs []any expected allTypes }{ - {"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}}, - {"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}}, - {"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}}, - {"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}}, - {"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}}, - {"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}}, - {"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}}, - {"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}}, - {"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}}, - {"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}}, - {"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}}, - {"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}}, - {"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}}, - {"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}}, - {"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, - {"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, + {"select $1::text", []any{expected.s}, []any{&actual.s}, allTypes{s: expected.s}}, + {"select $1::text", []any{zero.s}, []any{&actual.s}, allTypes{}}, + {"select $1::int2", []any{expected.i16}, []any{&actual.i16}, allTypes{i16: expected.i16}}, + {"select $1::int2", []any{zero.i16}, []any{&actual.i16}, allTypes{}}, + {"select $1::int4", []any{expected.i32}, []any{&actual.i32}, allTypes{i32: expected.i32}}, + {"select $1::int4", []any{zero.i32}, []any{&actual.i32}, allTypes{}}, + {"select $1::int8", []any{expected.i64}, []any{&actual.i64}, allTypes{i64: expected.i64}}, + {"select $1::int8", []any{zero.i64}, []any{&actual.i64}, allTypes{}}, + {"select $1::float4", []any{expected.f32}, []any{&actual.f32}, allTypes{f32: expected.f32}}, + {"select $1::float4", []any{zero.f32}, []any{&actual.f32}, allTypes{}}, + {"select $1::float8", []any{expected.f64}, []any{&actual.f64}, allTypes{f64: expected.f64}}, + {"select $1::float8", []any{zero.f64}, []any{&actual.f64}, allTypes{}}, + {"select $1::bool", []any{expected.b}, []any{&actual.b}, allTypes{b: expected.b}}, + {"select $1::bool", []any{zero.b}, []any{&actual.b}, allTypes{}}, + {"select $1::timestamptz", []any{expected.t}, []any{&actual.t}, allTypes{t: expected.t}}, + {"select $1::timestamptz", []any{zero.t}, []any{&actual.t}, allTypes{}}, } for i, tt := range tests { @@ -827,7 +803,7 @@ func TestPointerPointer(t *testing.T) { func TestPointerPointerNonZero(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { f := "foo" dest := &f @@ -844,7 +820,7 @@ func TestPointerPointerNonZero(t *testing.T) { func TestEncodeTypeRename(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { type _int int inInt := _int(1) var outInt _int @@ -889,6 +865,19 @@ func TestEncodeTypeRename(t *testing.T) { inString := _string("foo") var outString _string + // pgx.QueryExecModeExec requires all types to be registered. + conn.TypeMap().RegisterDefaultPgType(inInt, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt8, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt16, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt32, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt64, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint8, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint16, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint32, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint64, "int8") + conn.TypeMap().RegisterDefaultPgType(inString, "text") + err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text", inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString, ).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString) @@ -942,56 +931,56 @@ func TestEncodeTypeRename(t *testing.T) { }) } -func TestRowDecodeBinary(t *testing.T) { - t.Parallel() +// func TestRowDecodeBinary(t *testing.T) { +// t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) +// conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) +// defer closeConn(t, conn) - tests := []struct { - sql string - expected []interface{} - }{ - { - "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", - []interface{}{ - int32(1), - "cat", - time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), - }, - }, - { - "select row(100.0::float, 1.09::float)", - []interface{}{ - float64(100), - float64(1.09), - }, - }, - } +// tests := []struct { +// sql string +// expected []any +// }{ +// { +// "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", +// []any{ +// int32(1), +// "cat", +// time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), +// }, +// }, +// { +// "select row(100.0::float, 1.09::float)", +// []any{ +// float64(100), +// float64(1.09), +// }, +// }, +// } - for i, tt := range tests { - var actual []interface{} +// for i, tt := range tests { +// var actual []any - err := conn.QueryRow(context.Background(), tt.sql).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) - continue - } +// err := conn.QueryRow(context.Background(), tt.sql).Scan(&actual) +// if err != nil { +// t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) +// continue +// } - for j := range tt.expected { - assert.EqualValuesf(t, tt.expected[j], actual[j], "%d. [%d]", i, j) +// for j := range tt.expected { +// assert.EqualValuesf(t, tt.expected[j], actual[j], "%d. [%d]", i, j) - } +// } - ensureConnValid(t, conn) - } -} +// ensureConnValid(t, conn) +// } +// } // https://github.com/jackc/pgx/issues/810 func TestRowsScanNilThenScanValue(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { sql := `select null as a, null as b union select 1, 2 @@ -1033,6 +1022,7 @@ func TestScanIntoByteSlice(t *testing.T) { output []byte }{ {"int - text", "select 42", pgx.TextFormatCode, []byte("42")}, + {"int - binary", "select 42", pgx.BinaryFormatCode, []byte("42")}, {"text - text", "select 'hi'", pgx.TextFormatCode, []byte("hi")}, {"text - binary", "select 'hi'", pgx.BinaryFormatCode, []byte("hi")}, {"json - text", "select '{}'::json", pgx.TextFormatCode, []byte("{}")}, @@ -1047,19 +1037,4 @@ func TestScanIntoByteSlice(t *testing.T) { require.Equal(t, tt.output, buf) }) } - - // Failure cases - for _, tt := range []struct { - name string - sql string - err string - }{ - {"int binary", "select 42", "can't scan into dest[0]: cannot assign 42 into *[]uint8"}, - } { - t.Run(tt.name, func(t *testing.T) { - var buf []byte - err := conn.QueryRow(context.Background(), tt.sql, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&buf) - require.EqualError(t, err, tt.err) - }) - } }