Merge remote-tracking branch 'upstream/master'
This commit is contained in:
@@ -9,14 +9,11 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
name: Test
|
name: Test
|
||||||
# Note: The TLS tests are rather finicky. It seems that openssl 3 encrypts certificates differently than older
|
runs-on: ubuntu-22.04
|
||||||
# openssl and it does it in a way Go and/or pgx ssl handling code can't handle. So stick with Ubuntu 20.04 until
|
|
||||||
# that is figured out.
|
|
||||||
runs-on: ubuntu-20.04
|
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go-version: ["1.20", "1.21"]
|
go-version: ["1.21", "1.22"]
|
||||||
pg-version: [12, 13, 14, 15, 16, cockroachdb]
|
pg-version: [12, 13, 14, 15, 16, cockroachdb]
|
||||||
include:
|
include:
|
||||||
- pg-version: 12
|
- pg-version: 12
|
||||||
@@ -77,7 +74,7 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Go ${{ matrix.go-version }}
|
- name: Set up Go ${{ matrix.go-version }}
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: ${{ matrix.go-version }}
|
go-version: ${{ matrix.go-version }}
|
||||||
|
|
||||||
@@ -109,7 +106,8 @@ jobs:
|
|||||||
git diff --exit-code
|
git diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: go test -v -race ./...
|
# parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
|
||||||
|
run: go test -parallel=1 -race ./...
|
||||||
env:
|
env:
|
||||||
PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
|
PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
|
||||||
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
|
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
|
||||||
@@ -127,7 +125,7 @@ jobs:
|
|||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go-version: ["1.20", "1.21"]
|
go-version: ["1.21", "1.22"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Setup PostgreSQL
|
- name: Setup PostgreSQL
|
||||||
@@ -140,7 +138,7 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Go ${{ matrix.go-version }}
|
- name: Set up Go ${{ matrix.go-version }}
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: ${{ matrix.go-version }}
|
go-version: ${{ matrix.go-version }}
|
||||||
|
|
||||||
@@ -152,6 +150,7 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: go test -v -race ./...
|
# parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
|
||||||
|
run: go test -parallel=1 -race ./...
|
||||||
env:
|
env:
|
||||||
PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }}
|
PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }}
|
||||||
|
|||||||
@@ -1,3 +1,39 @@
|
|||||||
|
# 5.5.5 (March 9, 2024)
|
||||||
|
|
||||||
|
Use spaces instead of parentheses for SQL sanitization.
|
||||||
|
|
||||||
|
This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as
|
||||||
|
`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed.
|
||||||
|
|
||||||
|
# 5.5.4 (March 4, 2024)
|
||||||
|
|
||||||
|
Fix CVE-2024-27304
|
||||||
|
|
||||||
|
SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer
|
||||||
|
overflow in the calculated message size can cause the one large message to be sent as multiple messages under the
|
||||||
|
attacker's control.
|
||||||
|
|
||||||
|
Thanks to Paul Gerste for reporting this issue.
|
||||||
|
|
||||||
|
* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix)
|
||||||
|
* Fix simple protocol encoding of json.RawMessage
|
||||||
|
* Fix *Pipeline.getResults should close pipeline on error
|
||||||
|
* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman)
|
||||||
|
* Fix deallocation of invalidated cached statements in a transaction
|
||||||
|
* Handle invalid sslkey file
|
||||||
|
* Fix scan float4 into sql.Scanner
|
||||||
|
* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads.
|
||||||
|
|
||||||
|
# 5.5.3 (February 3, 2024)
|
||||||
|
|
||||||
|
* Fix: prepared statement already exists
|
||||||
|
* Improve CopyFrom auto-conversion of text-ish values
|
||||||
|
* Add ltree type support (Florent Viel)
|
||||||
|
* Make some properties of Batch and QueuedQuery public (Pavlo Golub)
|
||||||
|
* Add AppendRows function (Edoardo Spadolini)
|
||||||
|
* Optimize convert UUID [16]byte to string (Kirill Malikov)
|
||||||
|
* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar)
|
||||||
|
|
||||||
# 5.5.2 (January 13, 2024)
|
# 5.5.2 (January 13, 2024)
|
||||||
|
|
||||||
* Allow NamedArgs to start with underscore
|
* Allow NamedArgs to start with underscore
|
||||||
|
|||||||
+2
-16
@@ -79,20 +79,11 @@ echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql
|
|||||||
echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
||||||
cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
||||||
cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf
|
cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf
|
||||||
cp testsetup/ca.cnf .testdb
|
|
||||||
cp testsetup/localhost.cnf .testdb
|
|
||||||
cp testsetup/pgx_sslcert.cnf .testdb
|
|
||||||
|
|
||||||
cd .testdb
|
cd .testdb
|
||||||
|
|
||||||
# Generate a CA public / private key pair.
|
# Generate CA, server, and encrypted client certificates.
|
||||||
openssl genrsa -out ca.key 4096
|
go run ../testsetup/generate_certs.go
|
||||||
openssl req -x509 -config ca.cnf -new -nodes -key ca.key -sha256 -days 365 -subj '/O=pgx-test-root' -out ca.pem
|
|
||||||
|
|
||||||
# Generate the certificate for localhost (the server).
|
|
||||||
openssl genrsa -out localhost.key 2048
|
|
||||||
openssl req -new -config localhost.cnf -key localhost.key -out localhost.csr
|
|
||||||
openssl x509 -req -in localhost.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out localhost.crt -days 364 -sha256 -extfile localhost.cnf -extensions v3_req
|
|
||||||
|
|
||||||
# Copy certificates to server directory and set permissions.
|
# Copy certificates to server directory and set permissions.
|
||||||
cp ca.pem $POSTGRESQL_DATA_DIR/root.crt
|
cp ca.pem $POSTGRESQL_DATA_DIR/root.crt
|
||||||
@@ -100,11 +91,6 @@ cp localhost.key $POSTGRESQL_DATA_DIR/server.key
|
|||||||
chmod 600 $POSTGRESQL_DATA_DIR/server.key
|
chmod 600 $POSTGRESQL_DATA_DIR/server.key
|
||||||
cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt
|
cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt
|
||||||
|
|
||||||
# Generate the certificate for client authentication.
|
|
||||||
openssl genrsa -des3 -out pgx_sslcert.key -passout pass:certpw 2048
|
|
||||||
openssl req -new -config pgx_sslcert.cnf -key pgx_sslcert.key -passin pass:certpw -out pgx_sslcert.csr
|
|
||||||
openssl x509 -req -in pgx_sslcert.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out pgx_sslcert.crt -days 363 -sha256 -extfile pgx_sslcert.cnf -extensions v3_req
|
|
||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ pgerrcode contains constants for the PostgreSQL error codes.
|
|||||||
|
|
||||||
- [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
|
- [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
|
||||||
- [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
|
- [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
|
||||||
|
- [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
|
||||||
- [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
|
- [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
|
||||||
|
|
||||||
## Adapters for 3rd Party Tracers
|
## Adapters for 3rd Party Tracers
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
|
|
||||||
// QueuedQuery is a query that has been queued for execution via a Batch.
|
// QueuedQuery is a query that has been queued for execution via a Batch.
|
||||||
type QueuedQuery struct {
|
type QueuedQuery struct {
|
||||||
query string
|
SQL string
|
||||||
arguments []any
|
Arguments []any
|
||||||
fn batchItemFunc
|
fn batchItemFunc
|
||||||
sd *pgconn.StatementDescription
|
sd *pgconn.StatementDescription
|
||||||
}
|
}
|
||||||
@@ -57,7 +57,7 @@ func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
|
|||||||
// Batch queries are a way of bundling multiple queries together to avoid
|
// Batch queries are a way of bundling multiple queries together to avoid
|
||||||
// unnecessary network round trips. A Batch must only be sent once.
|
// unnecessary network round trips. A Batch must only be sent once.
|
||||||
type Batch struct {
|
type Batch struct {
|
||||||
queuedQueries []*QueuedQuery
|
QueuedQueries []*QueuedQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
|
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
|
||||||
@@ -65,16 +65,16 @@ type Batch struct {
|
|||||||
// connection's DefaultQueryExecMode.
|
// connection's DefaultQueryExecMode.
|
||||||
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
|
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
|
||||||
qq := &QueuedQuery{
|
qq := &QueuedQuery{
|
||||||
query: query,
|
SQL: query,
|
||||||
arguments: arguments,
|
Arguments: arguments,
|
||||||
}
|
}
|
||||||
b.queuedQueries = append(b.queuedQueries, qq)
|
b.QueuedQueries = append(b.QueuedQueries, qq)
|
||||||
return qq
|
return qq
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len returns number of queries that have been queued so far.
|
// Len returns number of queries that have been queued so far.
|
||||||
func (b *Batch) Len() int {
|
func (b *Batch) Len() int {
|
||||||
return len(b.queuedQueries)
|
return len(b.QueuedQueries)
|
||||||
}
|
}
|
||||||
|
|
||||||
type BatchResults interface {
|
type BatchResults interface {
|
||||||
@@ -227,9 +227,9 @@ func (br *batchResults) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Read and run fn for all remaining items
|
// Read and run fn for all remaining items
|
||||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||||
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
if br.b.QueuedQueries[br.qqIdx].fn != nil {
|
||||||
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
err := br.b.QueuedQueries[br.qqIdx].fn(br)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
br.err = err
|
br.err = err
|
||||||
}
|
}
|
||||||
@@ -253,10 +253,10 @@ func (br *batchResults) earlyError() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||||
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||||
bi := br.b.queuedQueries[br.qqIdx]
|
bi := br.b.QueuedQueries[br.qqIdx]
|
||||||
query = bi.query
|
query = bi.SQL
|
||||||
args = bi.arguments
|
args = bi.Arguments
|
||||||
ok = true
|
ok = true
|
||||||
br.qqIdx++
|
br.qqIdx++
|
||||||
}
|
}
|
||||||
@@ -396,9 +396,9 @@ func (br *pipelineBatchResults) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Read and run fn for all remaining items
|
// Read and run fn for all remaining items
|
||||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||||
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
if br.b.QueuedQueries[br.qqIdx].fn != nil {
|
||||||
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
err := br.b.QueuedQueries[br.qqIdx].fn(br)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
br.err = err
|
br.err = err
|
||||||
}
|
}
|
||||||
@@ -422,10 +422,10 @@ func (br *pipelineBatchResults) earlyError() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||||
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||||
bi := br.b.queuedQueries[br.qqIdx]
|
bi := br.b.QueuedQueries[br.qqIdx]
|
||||||
query = bi.query
|
query = bi.SQL
|
||||||
args = bi.arguments
|
args = bi.Arguments
|
||||||
ok = true
|
ok = true
|
||||||
br.qqIdx++
|
br.qqIdx++
|
||||||
}
|
}
|
||||||
|
|||||||
+2
-13
@@ -16,14 +16,8 @@ then
|
|||||||
|
|
||||||
cd testsetup
|
cd testsetup
|
||||||
|
|
||||||
# Generate a CA public / private key pair.
|
# Generate CA, server, and encrypted client certificates.
|
||||||
openssl genrsa -out ca.key 4096
|
go run generate_certs.go
|
||||||
openssl req -x509 -config ca.cnf -new -nodes -key ca.key -sha256 -days 365 -subj '/O=pgx-test-root' -out ca.pem
|
|
||||||
|
|
||||||
# Generate the certificate for localhost (the server).
|
|
||||||
openssl genrsa -out localhost.key 2048
|
|
||||||
openssl req -new -config localhost.cnf -key localhost.key -out localhost.csr
|
|
||||||
openssl x509 -req -in localhost.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out localhost.crt -days 364 -sha256 -extfile localhost.cnf -extensions v3_req
|
|
||||||
|
|
||||||
# Copy certificates to server directory and set permissions.
|
# Copy certificates to server directory and set permissions.
|
||||||
sudo cp ca.pem /var/lib/postgresql/$PGVERSION/main/root.crt
|
sudo cp ca.pem /var/lib/postgresql/$PGVERSION/main/root.crt
|
||||||
@@ -34,11 +28,6 @@ then
|
|||||||
sudo cp localhost.crt /var/lib/postgresql/$PGVERSION/main/server.crt
|
sudo cp localhost.crt /var/lib/postgresql/$PGVERSION/main/server.crt
|
||||||
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.crt
|
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.crt
|
||||||
|
|
||||||
# Generate the certificate for client authentication.
|
|
||||||
openssl genrsa -des3 -out pgx_sslcert.key -passout pass:certpw 2048
|
|
||||||
openssl req -new -config pgx_sslcert.cnf -key pgx_sslcert.key -passin pass:certpw -out pgx_sslcert.csr
|
|
||||||
openssl x509 -req -in pgx_sslcert.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out pgx_sslcert.crt -days 363 -sha256 -extfile pgx_sslcert.cnf -extensions v3_req
|
|
||||||
|
|
||||||
cp ca.pem /tmp
|
cp ca.pem /tmp
|
||||||
cp pgx_sslcert.key /tmp
|
cp pgx_sslcert.key /tmp
|
||||||
cp pgx_sslcert.crt /tmp
|
cp pgx_sslcert.crt /tmp
|
||||||
|
|||||||
@@ -903,10 +903,10 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
|||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, bi := range b.queuedQueries {
|
for _, bi := range b.QueuedQueries {
|
||||||
var queryRewriter QueryRewriter
|
var queryRewriter QueryRewriter
|
||||||
sql := bi.query
|
sql := bi.SQL
|
||||||
arguments := bi.arguments
|
arguments := bi.Arguments
|
||||||
|
|
||||||
optionLoop:
|
optionLoop:
|
||||||
for len(arguments) > 0 {
|
for len(arguments) > 0 {
|
||||||
@@ -928,8 +928,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bi.query = sql
|
bi.SQL = sql
|
||||||
bi.arguments = arguments
|
bi.Arguments = arguments
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: changing mode per batch? Update Batch.Queue function comment when implemented
|
// TODO: changing mode per batch? Update Batch.Queue function comment when implemented
|
||||||
@@ -939,8 +939,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// All other modes use extended protocol and thus can use prepared statements.
|
// All other modes use extended protocol and thus can use prepared statements.
|
||||||
for _, bi := range b.queuedQueries {
|
for _, bi := range b.QueuedQueries {
|
||||||
if sd, ok := c.preparedStatements[bi.query]; ok {
|
if sd, ok := c.preparedStatements[bi.SQL]; ok {
|
||||||
bi.sd = sd
|
bi.sd = sd
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -961,11 +961,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
|||||||
|
|
||||||
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
|
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for i, bi := range b.queuedQueries {
|
for i, bi := range b.QueuedQueries {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
sb.WriteByte(';')
|
sb.WriteByte(';')
|
||||||
}
|
}
|
||||||
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
|
sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
}
|
}
|
||||||
@@ -984,21 +984,21 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc
|
|||||||
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
|
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
|
||||||
batch := &pgconn.Batch{}
|
batch := &pgconn.Batch{}
|
||||||
|
|
||||||
for _, bi := range b.queuedQueries {
|
for _, bi := range b.QueuedQueries {
|
||||||
sd := bi.sd
|
sd := bi.sd
|
||||||
if sd != nil {
|
if sd != nil {
|
||||||
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
err := c.eqb.Build(c.typeMap, sd, bi.Arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||||
} else {
|
} else {
|
||||||
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
|
err := c.eqb.Build(c.typeMap, nil, bi.Arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
}
|
}
|
||||||
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1023,18 +1023,18 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
|
|||||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||||
distinctNewQueriesIdxMap := make(map[string]int)
|
distinctNewQueriesIdxMap := make(map[string]int)
|
||||||
|
|
||||||
for _, bi := range b.queuedQueries {
|
for _, bi := range b.QueuedQueries {
|
||||||
if bi.sd == nil {
|
if bi.sd == nil {
|
||||||
sd := c.statementCache.Get(bi.query)
|
sd := c.statementCache.Get(bi.SQL)
|
||||||
if sd != nil {
|
if sd != nil {
|
||||||
bi.sd = sd
|
bi.sd = sd
|
||||||
} else {
|
} else {
|
||||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
|
||||||
bi.sd = distinctNewQueries[idx]
|
bi.sd = distinctNewQueries[idx]
|
||||||
} else {
|
} else {
|
||||||
sd = &pgconn.StatementDescription{
|
sd = &pgconn.StatementDescription{
|
||||||
Name: stmtcache.StatementName(bi.query),
|
Name: stmtcache.StatementName(bi.SQL),
|
||||||
SQL: bi.query,
|
SQL: bi.SQL,
|
||||||
}
|
}
|
||||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||||
distinctNewQueries = append(distinctNewQueries, sd)
|
distinctNewQueries = append(distinctNewQueries, sd)
|
||||||
@@ -1055,17 +1055,17 @@ func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch
|
|||||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||||
distinctNewQueriesIdxMap := make(map[string]int)
|
distinctNewQueriesIdxMap := make(map[string]int)
|
||||||
|
|
||||||
for _, bi := range b.queuedQueries {
|
for _, bi := range b.QueuedQueries {
|
||||||
if bi.sd == nil {
|
if bi.sd == nil {
|
||||||
sd := c.descriptionCache.Get(bi.query)
|
sd := c.descriptionCache.Get(bi.SQL)
|
||||||
if sd != nil {
|
if sd != nil {
|
||||||
bi.sd = sd
|
bi.sd = sd
|
||||||
} else {
|
} else {
|
||||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
|
||||||
bi.sd = distinctNewQueries[idx]
|
bi.sd = distinctNewQueries[idx]
|
||||||
} else {
|
} else {
|
||||||
sd = &pgconn.StatementDescription{
|
sd = &pgconn.StatementDescription{
|
||||||
SQL: bi.query,
|
SQL: bi.SQL,
|
||||||
}
|
}
|
||||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||||
distinctNewQueries = append(distinctNewQueries, sd)
|
distinctNewQueries = append(distinctNewQueries, sd)
|
||||||
@@ -1082,13 +1082,13 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
|
|||||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||||
distinctNewQueriesIdxMap := make(map[string]int)
|
distinctNewQueriesIdxMap := make(map[string]int)
|
||||||
|
|
||||||
for _, bi := range b.queuedQueries {
|
for _, bi := range b.QueuedQueries {
|
||||||
if bi.sd == nil {
|
if bi.sd == nil {
|
||||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
|
||||||
bi.sd = distinctNewQueries[idx]
|
bi.sd = distinctNewQueries[idx]
|
||||||
} else {
|
} else {
|
||||||
sd := &pgconn.StatementDescription{
|
sd := &pgconn.StatementDescription{
|
||||||
SQL: bi.query,
|
SQL: bi.SQL,
|
||||||
}
|
}
|
||||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||||
distinctNewQueries = append(distinctNewQueries, sd)
|
distinctNewQueries = append(distinctNewQueries, sd)
|
||||||
@@ -1154,11 +1154,11 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Queue the queries.
|
// Queue the queries.
|
||||||
for _, bi := range b.queuedQueries {
|
for _, bi := range b.QueuedQueries {
|
||||||
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
|
err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// we wrap the error so we the user can understand which query failed inside the batch
|
// we wrap the error so we the user can understand which query failed inside the batch
|
||||||
err = fmt.Errorf("error building query %s: %w", bi.query, err)
|
err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
|
||||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1203,7 +1203,15 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
|
|||||||
return sanitize.SanitizeSQL(sql, valueArgs...)
|
return sanitize.SanitizeSQL(sql, valueArgs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration.
|
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be
|
||||||
|
// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular,
|
||||||
|
// typeName must be one of the following:
|
||||||
|
// - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered.
|
||||||
|
// - A composite type name where all field types are already registered.
|
||||||
|
// - A domain type name where the base type is already registered.
|
||||||
|
// - An enum type name.
|
||||||
|
// - A range type name where the element type is already registered.
|
||||||
|
// - A multirange type name where the element type is already registered.
|
||||||
func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
|
func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
|
||||||
var oid uint32
|
var oid uint32
|
||||||
|
|
||||||
@@ -1346,17 +1354,17 @@ order by attnum`,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
|
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
|
||||||
if c.pgConn.TxStatus() != 'I' {
|
if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.descriptionCache != nil {
|
if c.descriptionCache != nil {
|
||||||
c.descriptionCache.HandleInvalidated()
|
c.descriptionCache.RemoveInvalidated()
|
||||||
}
|
}
|
||||||
|
|
||||||
var invalidatedStatements []*pgconn.StatementDescription
|
var invalidatedStatements []*pgconn.StatementDescription
|
||||||
if c.statementCache != nil {
|
if c.statementCache != nil {
|
||||||
invalidatedStatements = c.statementCache.HandleInvalidated()
|
invalidatedStatements = c.statementCache.GetInvalidated()
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(invalidatedStatements) == 0 {
|
if len(invalidatedStatements) == 0 {
|
||||||
@@ -1368,7 +1376,6 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
|
|||||||
|
|
||||||
for _, sd := range invalidatedStatements {
|
for _, sd := range invalidatedStatements {
|
||||||
pipeline.SendDeallocate(sd.Name)
|
pipeline.SendDeallocate(sd.Name)
|
||||||
delete(c.preparedStatements, sd.Name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := pipeline.Sync()
|
err := pipeline.Sync()
|
||||||
@@ -1381,5 +1388,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
|
|||||||
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.statementCache.RemoveInvalidated()
|
||||||
|
for _, sd := range invalidatedStatements {
|
||||||
|
delete(c.preparedStatements, sd.Name)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1338,3 +1338,73 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
|
|||||||
t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not")
|
t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgx/issues/1847
|
||||||
|
func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
pgxtest.SkipCockroachDB(t, conn, "CockroachDB returns decimal instead of integer for integer division")
|
||||||
|
|
||||||
|
var n int32
|
||||||
|
err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 1, n)
|
||||||
|
|
||||||
|
// Divide by zero causes an error. baseRows.Close() calls Invalidate on the statement cache whenever an error was
|
||||||
|
// encountered by the query. Use this to purposely invalidate the query. If we had access to private fields of conn
|
||||||
|
// we could call conn.statementCache.InvalidateAll() instead.
|
||||||
|
err = conn.QueryRow(ctx, "select 1 / $1::int", 0).Scan(&n)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
ctx2, cancel2 := context.WithCancel(ctx)
|
||||||
|
cancel2()
|
||||||
|
err = conn.QueryRow(ctx2, "select 1 / $1::int", 1).Scan(&n)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
|
||||||
|
err = conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 1, n)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgx/issues/1847
|
||||||
|
func TestConnDeallocateInvalidatedCachedStatementsInTransactionWithBatch(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
connString := os.Getenv("PGX_TEST_DATABASE")
|
||||||
|
config := mustParseConfig(t, connString)
|
||||||
|
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
|
||||||
|
config.StatementCacheCapacity = 2
|
||||||
|
|
||||||
|
conn, err := pgx.ConnectConfig(ctx, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tx, err := conn.Begin(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
|
_, err = tx.Exec(ctx, "select $1::int + 1", 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = tx.Exec(ctx, "select $1::int + 2", 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// This should invalidate the first cached statement.
|
||||||
|
_, err = tx.Exec(ctx, "select $1::int + 3", 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
batch := &pgx.Batch{}
|
||||||
|
batch.Queue("select $1::int + 1", 1)
|
||||||
|
err = tx.SendBatch(ctx, batch).Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = tx.Rollback(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|||||||
@@ -803,6 +803,40 @@ func TestConnCopyFromAutomaticStringConversion(t *testing.T) {
|
|||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgx/discussions/1891
|
||||||
|
func TestConnCopyFromAutomaticStringConversionArray(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustExec(t, conn, `create temporary table foo(
|
||||||
|
a numeric[]
|
||||||
|
)`)
|
||||||
|
|
||||||
|
inputRows := [][]interface{}{
|
||||||
|
{[]string{"42"}},
|
||||||
|
{[]string{"7"}},
|
||||||
|
{[]string{"8", "9"}},
|
||||||
|
{[][]string{{"10", "11"}, {"12", "13"}}},
|
||||||
|
}
|
||||||
|
|
||||||
|
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, len(inputRows), copyCount)
|
||||||
|
|
||||||
|
// Test reads as int64 and flattened array for simplicity.
|
||||||
|
rows, _ := conn.Query(ctx, "select * from foo")
|
||||||
|
nums, err := pgx.CollectRows(rows, pgx.RowTo[[]int64])
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, [][]int64{{42}, {7}, {8, 9}, {10, 11, 12, 13}}, nums)
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCopyFromFunc(t *testing.T) {
|
func TestCopyFromFunc(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ module github.com/andoma-go/pgx/v5
|
|||||||
go 1.19
|
go 1.19
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/andoma-go/pgpassfile v1.0.0
|
github.com/andoma-go/pgpassfile v0.0.0-20240115130830-7bdd00f68544
|
||||||
github.com/andoma-go/pgservicefile v0.0.0-20240115131304-4a01ebf23c42
|
github.com/andoma-go/pgservicefile v0.0.0-20240115131304-4a01ebf23c42
|
||||||
github.com/andoma-go/puddle/v2 v2.2.1
|
github.com/andoma-go/puddle/v2 v2.0.0-20240328142435-357666cb6fa1
|
||||||
github.com/stretchr/testify v1.8.1
|
github.com/stretchr/testify v1.8.1
|
||||||
golang.org/x/crypto v0.17.0
|
golang.org/x/crypto v0.17.0
|
||||||
golang.org/x/text v0.14.0
|
golang.org/x/text v0.14.0
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
github.com/andoma-go/pgpassfile v1.0.0 h1:IJZAs6b/3pmEnq0kAvBWh2qEPsQOleHVIzMzj8WwT4w=
|
github.com/andoma-go/pgpassfile v0.0.0-20240115130830-7bdd00f68544 h1:zw0WuRyP2Awzl63MI2VwMXSM/CsNqwygHg/CbySE1ls=
|
||||||
github.com/andoma-go/pgpassfile v1.0.0/go.mod h1:JWSeNzz3oUhysdQgq1OL4PyV3R4QW/KyPvqqEykyN88=
|
github.com/andoma-go/pgpassfile v0.0.0-20240115130830-7bdd00f68544/go.mod h1:JWSeNzz3oUhysdQgq1OL4PyV3R4QW/KyPvqqEykyN88=
|
||||||
github.com/andoma-go/pgservicefile v0.0.0-20240115131304-4a01ebf23c42 h1:TpYPPFFHiqFDM0luTfDiHBdGSgYU+uloD+FaA87BBRk=
|
github.com/andoma-go/pgservicefile v0.0.0-20240115131304-4a01ebf23c42 h1:TpYPPFFHiqFDM0luTfDiHBdGSgYU+uloD+FaA87BBRk=
|
||||||
github.com/andoma-go/pgservicefile v0.0.0-20240115131304-4a01ebf23c42/go.mod h1:iRoNsjH6Wp9dCo0oiT1geVOjYusx6RUIdzCJNktFso0=
|
github.com/andoma-go/pgservicefile v0.0.0-20240115131304-4a01ebf23c42/go.mod h1:iRoNsjH6Wp9dCo0oiT1geVOjYusx6RUIdzCJNktFso0=
|
||||||
github.com/andoma-go/puddle/v2 v2.2.1 h1:cobxhnZmYsynXC9k8xcJd97ytlCa/Pe5kgj69pgncrE=
|
github.com/andoma-go/puddle/v2 v2.0.0-20240328142435-357666cb6fa1 h1:3/6Uu7EWnHeHAwZ9tfytqJy+1x8LTtYrsWGczhMJ4uc=
|
||||||
github.com/andoma-go/puddle/v2 v2.2.1/go.mod h1:iWHUHOdNa1/WJ6MyJAZ5qeTI/sJMbjVK/Gw4JLjh4Dw=
|
github.com/andoma-go/puddle/v2 v2.0.0-20240328142435-357666cb6fa1/go.mod h1:iWHUHOdNa1/WJ6MyJAZ5qeTI/sJMbjVK/Gw4JLjh4Dw=
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
|||||||
@@ -63,6 +63,10 @@ func (q *Query) Sanitize(args ...any) (string, error) {
|
|||||||
return "", fmt.Errorf("invalid arg type: %T", arg)
|
return "", fmt.Errorf("invalid arg type: %T", arg)
|
||||||
}
|
}
|
||||||
argUse[argIdx] = true
|
argUse[argIdx] = true
|
||||||
|
|
||||||
|
// Prevent SQL injection via Line Comment Creation
|
||||||
|
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||||
|
str = " " + str + " "
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("invalid Part type: %T", part)
|
return "", fmt.Errorf("invalid Part type: %T", part)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -132,47 +132,57 @@ func TestQuerySanitize(t *testing.T) {
|
|||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []any{int64(42)},
|
args: []any{int64(42)},
|
||||||
expected: `select 42`,
|
expected: `select 42 `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []any{float64(1.23)},
|
args: []any{float64(1.23)},
|
||||||
expected: `select 1.23`,
|
expected: `select 1.23 `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []any{true},
|
args: []any{true},
|
||||||
expected: `select true`,
|
expected: `select true `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []any{[]byte{0, 1, 2, 3, 255}},
|
args: []any{[]byte{0, 1, 2, 3, 255}},
|
||||||
expected: `select '\x00010203ff'`,
|
expected: `select '\x00010203ff' `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []any{nil},
|
args: []any{nil},
|
||||||
expected: `select null`,
|
expected: `select null `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []any{"foobar"},
|
args: []any{"foobar"},
|
||||||
expected: `select 'foobar'`,
|
expected: `select 'foobar' `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []any{"foo'bar"},
|
args: []any{"foo'bar"},
|
||||||
expected: `select 'foo''bar'`,
|
expected: `select 'foo''bar' `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []any{`foo\'bar`},
|
args: []any{`foo\'bar`},
|
||||||
expected: `select 'foo\''bar'`,
|
expected: `select 'foo\''bar' `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}},
|
||||||
args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
|
args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
|
||||||
expected: `insert '2020-03-01 23:59:59.999999Z'`,
|
expected: `insert '2020-03-01 23:59:59.999999Z' `,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
|
||||||
|
args: []any{int64(-1)},
|
||||||
|
expected: `select 1- -1 `,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
|
||||||
|
args: []any{float64(-1)},
|
||||||
|
expected: `select 1- -1 `,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -81,12 +81,16 @@ func (c *LRUCache) InvalidateAll() {
|
|||||||
c.l = list.New()
|
c.l = list.New()
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
|
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||||
// Typically, the caller will then deallocate them.
|
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
|
||||||
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
|
return c.invalidStmts
|
||||||
invalidStmts := c.invalidStmts
|
}
|
||||||
|
|
||||||
|
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||||
|
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||||
|
// never seen by the call to GetInvalidated.
|
||||||
|
func (c *LRUCache) RemoveInvalidated() {
|
||||||
c.invalidStmts = nil
|
c.invalidStmts = nil
|
||||||
return invalidStmts
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len returns the number of cached prepared statement descriptions.
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
|
|||||||
@@ -29,8 +29,13 @@ type Cache interface {
|
|||||||
// InvalidateAll invalidates all statement descriptions.
|
// InvalidateAll invalidates all statement descriptions.
|
||||||
InvalidateAll()
|
InvalidateAll()
|
||||||
|
|
||||||
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
|
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||||
HandleInvalidated() []*pgconn.StatementDescription
|
GetInvalidated() []*pgconn.StatementDescription
|
||||||
|
|
||||||
|
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||||
|
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||||
|
// never seen by the call to GetInvalidated.
|
||||||
|
RemoveInvalidated()
|
||||||
|
|
||||||
// Len returns the number of cached prepared statement descriptions.
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
Len() int
|
Len() int
|
||||||
|
|||||||
@@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() {
|
|||||||
c.m = make(map[string]*pgconn.StatementDescription)
|
c.m = make(map[string]*pgconn.StatementDescription)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
|
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||||
invalidStmts := c.invalidStmts
|
func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
|
||||||
|
return c.invalidStmts
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||||
|
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||||
|
// never seen by the call to GetInvalidated.
|
||||||
|
func (c *UnlimitedCache) RemoveInvalidated() {
|
||||||
c.invalidStmts = nil
|
c.invalidStmts = nil
|
||||||
return invalidStmts
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len returns the number of cached prepared statement descriptions.
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
|
|||||||
+55
-26
@@ -6,6 +6,11 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
|
||||||
|
// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data
|
||||||
|
// in the message,maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB.
|
||||||
|
var maxLargeObjectMessageLength = 1024*1024*1024 - 1024
|
||||||
|
|
||||||
// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it
|
// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it
|
||||||
// was created.
|
// was created.
|
||||||
//
|
//
|
||||||
@@ -67,41 +72,65 @@ type LargeObject struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
|
// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
|
||||||
//
|
|
||||||
// Write is implemented with a single call to lowrite. The PostgreSQL wire protocol has a limit of 1 GB - 1 per message.
|
|
||||||
// See definition of PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data in the message,
|
|
||||||
// len(p) should be no larger than 1 GB - 1 KB.
|
|
||||||
func (o *LargeObject) Write(p []byte) (int, error) {
|
func (o *LargeObject) Write(p []byte) (int, error) {
|
||||||
var n int
|
nTotal := 0
|
||||||
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n)
|
for {
|
||||||
if err != nil {
|
expected := len(p) - nTotal
|
||||||
return n, err
|
if expected == 0 {
|
||||||
|
break
|
||||||
|
} else if expected > maxLargeObjectMessageLength {
|
||||||
|
expected = maxLargeObjectMessageLength
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int
|
||||||
|
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n)
|
||||||
|
if err != nil {
|
||||||
|
return nTotal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n < 0 {
|
||||||
|
return nTotal, errors.New("failed to write to large object")
|
||||||
|
}
|
||||||
|
|
||||||
|
nTotal += n
|
||||||
|
|
||||||
|
if n < expected {
|
||||||
|
return nTotal, errors.New("short write to large object")
|
||||||
|
} else if n > expected {
|
||||||
|
return nTotal, errors.New("invalid write to large object")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if n < 0 {
|
return nTotal, nil
|
||||||
return 0, errors.New("failed to write to large object")
|
|
||||||
}
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read reads up to len(p) bytes into p returning the number of bytes read.
|
// Read reads up to len(p) bytes into p returning the number of bytes read.
|
||||||
//
|
|
||||||
// Read is implemented with a single call to loread. PostgreSQL internally allocates a single buffer for the response.
|
|
||||||
// The largest buffer PostgreSQL will allocate is 1 GB - 1. See definition of MaxAllocSize in the PostgreSQL source
|
|
||||||
// code. To allow for the other data in the message, len(p) should be no larger than 1 GB - 1 KB.
|
|
||||||
func (o *LargeObject) Read(p []byte) (int, error) {
|
func (o *LargeObject) Read(p []byte) (int, error) {
|
||||||
var res []byte
|
nTotal := 0
|
||||||
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res)
|
for {
|
||||||
copy(p, res)
|
expected := len(p) - nTotal
|
||||||
if err != nil {
|
if expected == 0 {
|
||||||
return len(res), err
|
break
|
||||||
|
} else if expected > maxLargeObjectMessageLength {
|
||||||
|
expected = maxLargeObjectMessageLength
|
||||||
|
}
|
||||||
|
|
||||||
|
var res []byte
|
||||||
|
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
|
||||||
|
copy(p[nTotal:], res)
|
||||||
|
nTotal += len(res)
|
||||||
|
if err != nil {
|
||||||
|
return nTotal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res) < expected {
|
||||||
|
return nTotal, io.EOF
|
||||||
|
} else if len(res) > expected {
|
||||||
|
return nTotal, errors.New("invalid read of large object")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(res) < len(p) {
|
return nTotal, nil
|
||||||
err = io.EOF
|
|
||||||
}
|
|
||||||
return len(res), err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Seek moves the current location pointer to the new location specified by offset.
|
// Seek moves the current location pointer to the new location specified by offset.
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package pgx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetMaxLargeObjectMessageLength sets internal maxLargeObjectMessageLength variable
|
||||||
|
// to the given length for the duration of the test.
|
||||||
|
//
|
||||||
|
// Tests using this helper should not use t.Parallel().
|
||||||
|
func SetMaxLargeObjectMessageLength(t *testing.T, length int) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
original := maxLargeObjectMessageLength
|
||||||
|
t.Cleanup(func() {
|
||||||
|
maxLargeObjectMessageLength = original
|
||||||
|
})
|
||||||
|
|
||||||
|
maxLargeObjectMessageLength = length
|
||||||
|
}
|
||||||
@@ -13,7 +13,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestLargeObjects(t *testing.T) {
|
func TestLargeObjects(t *testing.T) {
|
||||||
t.Parallel()
|
// We use a very short limit to test chunking logic.
|
||||||
|
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -34,7 +35,8 @@ func TestLargeObjects(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLargeObjectsSimpleProtocol(t *testing.T) {
|
func TestLargeObjectsSimpleProtocol(t *testing.T) {
|
||||||
t.Parallel()
|
// We use a very short limit to test chunking logic.
|
||||||
|
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -160,7 +162,8 @@ func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLargeObjectsMultipleTransactions(t *testing.T) {
|
func TestLargeObjectsMultipleTransactions(t *testing.T) {
|
||||||
t.Parallel()
|
// We use a very short limit to test chunking logic.
|
||||||
|
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
+42
-16
@@ -2,6 +2,7 @@ package pgx
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
@@ -21,6 +22,34 @@ type NamedArgs map[string]any
|
|||||||
|
|
||||||
// RewriteQuery implements the QueryRewriter interface.
|
// RewriteQuery implements the QueryRewriter interface.
|
||||||
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||||
|
return rewriteQuery(na, sql, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
|
||||||
|
// named arguments that the sql query uses, and no extra arguments.
|
||||||
|
type StrictNamedArgs map[string]any
|
||||||
|
|
||||||
|
// RewriteQuery implements the QueryRewriter interface.
|
||||||
|
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||||
|
return rewriteQuery(sna, sql, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
type namedArg string
|
||||||
|
|
||||||
|
type sqlLexer struct {
|
||||||
|
src string
|
||||||
|
start int
|
||||||
|
pos int
|
||||||
|
nested int // multiline comment nesting level.
|
||||||
|
stateFn stateFn
|
||||||
|
parts []any
|
||||||
|
|
||||||
|
nameToOrdinal map[namedArg]int
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateFn func(*sqlLexer) stateFn
|
||||||
|
|
||||||
|
func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
|
||||||
l := &sqlLexer{
|
l := &sqlLexer{
|
||||||
src: sql,
|
src: sql,
|
||||||
stateFn: rawState,
|
stateFn: rawState,
|
||||||
@@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar
|
|||||||
|
|
||||||
newArgs = make([]any, len(l.nameToOrdinal))
|
newArgs = make([]any, len(l.nameToOrdinal))
|
||||||
for name, ordinal := range l.nameToOrdinal {
|
for name, ordinal := range l.nameToOrdinal {
|
||||||
newArgs[ordinal-1] = na[string(name)]
|
var found bool
|
||||||
|
newArgs[ordinal-1], found = na[string(name)]
|
||||||
|
if isStrict && !found {
|
||||||
|
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isStrict {
|
||||||
|
for name := range na {
|
||||||
|
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
|
||||||
|
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String(), newArgs, nil
|
return sb.String(), newArgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type namedArg string
|
|
||||||
|
|
||||||
type sqlLexer struct {
|
|
||||||
src string
|
|
||||||
start int
|
|
||||||
pos int
|
|
||||||
nested int // multiline comment nesting level.
|
|
||||||
stateFn stateFn
|
|
||||||
parts []any
|
|
||||||
|
|
||||||
nameToOrdinal map[namedArg]int
|
|
||||||
}
|
|
||||||
|
|
||||||
type stateFn func(*sqlLexer) stateFn
|
|
||||||
|
|
||||||
func rawState(l *sqlLexer) stateFn {
|
func rawState(l *sqlLexer) stateFn {
|
||||||
for {
|
for {
|
||||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
|||||||
@@ -93,6 +93,18 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
|||||||
where id = $1;`,
|
where id = $1;`,
|
||||||
expectedArgs: []any{int32(42)},
|
expectedArgs: []any{int32(42)},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
sql: "extra provided argument",
|
||||||
|
namedArgs: pgx.NamedArgs{"extra": int32(1)},
|
||||||
|
expectedSQL: "extra provided argument",
|
||||||
|
expectedArgs: []any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "@missing argument",
|
||||||
|
namedArgs: pgx.NamedArgs{},
|
||||||
|
expectedSQL: "$1 argument",
|
||||||
|
expectedArgs: []any{nil},
|
||||||
|
},
|
||||||
|
|
||||||
// test comments and quotes
|
// test comments and quotes
|
||||||
} {
|
} {
|
||||||
@@ -102,3 +114,49 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
|||||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStrictNamedArgsRewriteQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for i, tt := range []struct {
|
||||||
|
sql string
|
||||||
|
namedArgs pgx.StrictNamedArgs
|
||||||
|
expectedSQL string
|
||||||
|
expectedArgs []any
|
||||||
|
isExpectedError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sql: "no arguments",
|
||||||
|
namedArgs: pgx.StrictNamedArgs{},
|
||||||
|
expectedSQL: "no arguments",
|
||||||
|
expectedArgs: []any{},
|
||||||
|
isExpectedError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "@all @matches",
|
||||||
|
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
|
||||||
|
expectedSQL: "$1 $2",
|
||||||
|
expectedArgs: []any{int32(1), int32(2)},
|
||||||
|
isExpectedError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "extra provided argument",
|
||||||
|
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
|
||||||
|
isExpectedError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "@missing argument",
|
||||||
|
namedArgs: pgx.StrictNamedArgs{},
|
||||||
|
isExpectedError: true,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
|
||||||
|
if tt.isExpectedError {
|
||||||
|
assert.Errorf(t, err, "%d", i)
|
||||||
|
} else {
|
||||||
|
require.NoErrorf(t, err, "%d", i)
|
||||||
|
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||||
|
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -721,6 +721,9 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
|||||||
return nil, fmt.Errorf("unable to read sslkey: %w", err)
|
return nil, fmt.Errorf("unable to read sslkey: %w", err)
|
||||||
}
|
}
|
||||||
block, _ := pem.Decode(buf)
|
block, _ := pem.Decode(buf)
|
||||||
|
if block == nil {
|
||||||
|
return nil, errors.New("failed to decode sslkey")
|
||||||
|
}
|
||||||
var pemKey []byte
|
var pemKey []byte
|
||||||
var decryptedKey []byte
|
var decryptedKey []byte
|
||||||
var decryptedError error
|
var decryptedError error
|
||||||
|
|||||||
+43
-5
@@ -1674,25 +1674,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
|
|||||||
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
|
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
|
||||||
type Batch struct {
|
type Batch struct {
|
||||||
buf []byte
|
buf []byte
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
|
// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
|
||||||
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
|
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
|
||||||
batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
|
if batch.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
|
||||||
|
if batch.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
|
batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
|
// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
|
||||||
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
|
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
|
||||||
batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
|
if batch.err != nil {
|
||||||
batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
|
return
|
||||||
batch.buf = (&pgproto3.Execute{}).Encode(batch.buf)
|
}
|
||||||
|
|
||||||
|
batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
|
||||||
|
if batch.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
|
||||||
|
if batch.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
|
||||||
|
if batch.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
|
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
|
||||||
// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
|
// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
|
||||||
// multiple queries in a single round trip than using pipeline mode.
|
// multiple queries in a single round trip than using pipeline mode.
|
||||||
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
|
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
|
||||||
|
if batch.err != nil {
|
||||||
|
return &MultiResultReader{
|
||||||
|
closed: true,
|
||||||
|
err: batch.err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := pgConn.lock(); err != nil {
|
if err := pgConn.lock(); err != nil {
|
||||||
return &MultiResultReader{
|
return &MultiResultReader{
|
||||||
closed: true,
|
closed: true,
|
||||||
@@ -1718,7 +1748,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||||||
pgConn.contextWatcher.Watch(ctx)
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
|
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
|
||||||
|
if batch.err != nil {
|
||||||
|
multiResult.closed = true
|
||||||
|
multiResult.err = batch.err
|
||||||
|
pgConn.unlock()
|
||||||
|
return multiResult
|
||||||
|
}
|
||||||
|
|
||||||
pgConn.enterPotentialWriteReadDeadlock()
|
pgConn.enterPotentialWriteReadDeadlock()
|
||||||
defer pgConn.exitPotentialWriteReadDeadlock()
|
defer pgConn.exitPotentialWriteReadDeadlock()
|
||||||
@@ -2094,6 +2130,8 @@ func (p *Pipeline) getResults() (results any, err error) {
|
|||||||
for {
|
for {
|
||||||
msg, err := p.conn.receiveMessage()
|
msg, err := p.conn.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
p.closed = true
|
||||||
|
p.err = err
|
||||||
p.conn.asyncClose()
|
p.conn.asyncClose()
|
||||||
return nil, normalizeTimeoutError(p.ctx, err)
|
return nil, normalizeTimeoutError(p.ctx, err)
|
||||||
}
|
}
|
||||||
|
|||||||
+93
-3
@@ -3363,9 +3363,9 @@ func TestSNISupport(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil))
|
srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
|
||||||
srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))
|
srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
|
||||||
srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))
|
srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))
|
||||||
|
|
||||||
serverSNINameChan <- sniHost
|
serverSNINameChan <- sniHost
|
||||||
}()
|
}()
|
||||||
@@ -3389,3 +3389,93 @@ func TestSNISupport(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgx/issues/1920
|
||||||
|
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
steps := pgmock.AcceptUnauthenticatedConnRequestSteps()
|
||||||
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
|
||||||
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
|
||||||
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
|
||||||
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
|
||||||
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
|
||||||
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
|
||||||
|
steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
||||||
|
{Name: []byte("mock")},
|
||||||
|
}}))
|
||||||
|
steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
|
||||||
|
// We shouldn't get anything after the first fatal error. But the reported issue was with PgBouncer so maybe that
|
||||||
|
// causes the issue. Anyway, a FATAL error after the connection had already been killed could cause a panic.
|
||||||
|
steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
|
||||||
|
|
||||||
|
script := &pgmock.Script{Steps: steps}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
serverKeepAlive := make(chan struct{})
|
||||||
|
defer close(serverKeepAlive)
|
||||||
|
|
||||||
|
serverErrChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer close(serverErrChan)
|
||||||
|
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
serverErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
err = conn.SetDeadline(time.Now().Add(59 * time.Second))
|
||||||
|
if err != nil {
|
||||||
|
serverErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = script.Run(pgproto3.NewBackend(conn, conn))
|
||||||
|
if err != nil {
|
||||||
|
serverErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
<-serverKeepAlive
|
||||||
|
}()
|
||||||
|
|
||||||
|
parts := strings.Split(ln.Addr().String(), ":")
|
||||||
|
host := parts[0]
|
||||||
|
port := parts[1]
|
||||||
|
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
||||||
|
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, 59*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
conn, err := pgconn.Connect(ctx, connStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pipeline := conn.StartPipeline(ctx)
|
||||||
|
pipeline.SendPrepare("s1", "select 1", nil)
|
||||||
|
pipeline.SendPrepare("s2", "select 2", nil)
|
||||||
|
pipeline.SendPrepare("s3", "select 3", nil)
|
||||||
|
err = pipeline.Sync()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = pipeline.GetResults()
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
err = pipeline.Close()
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustEncode(buf []byte, err error) []byte {
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|||||||
@@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
|
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'R')
|
dst, sp := beginMessage(dst, 'R')
|
||||||
dst = pgio.AppendInt32(dst, 8)
|
|
||||||
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
|
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
|
||||||
return dst
|
return finishMessage(dst, sp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
|
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'R')
|
dst, sp := beginMessage(dst, 'R')
|
||||||
dst = pgio.AppendInt32(dst, 4)
|
|
||||||
dst = pgio.AppendUint32(dst, AuthTypeGSS)
|
dst = pgio.AppendUint32(dst, AuthTypeGSS)
|
||||||
return dst
|
return finishMessage(dst, sp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
|
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
|
||||||
|
|||||||
@@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
|
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'R')
|
dst, sp := beginMessage(dst, 'R')
|
||||||
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
|
|
||||||
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
|
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
|
||||||
dst = append(dst, a.Data...)
|
dst = append(dst, a.Data...)
|
||||||
return dst
|
return finishMessage(dst, sp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
|
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
|
||||||
|
|||||||
@@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
|
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'R')
|
dst, sp := beginMessage(dst, 'R')
|
||||||
dst = pgio.AppendInt32(dst, 12)
|
|
||||||
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
|
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
|
||||||
dst = append(dst, src.Salt[:]...)
|
dst = append(dst, src.Salt[:]...)
|
||||||
return dst
|
return finishMessage(dst, sp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *AuthenticationOk) Encode(dst []byte) []byte {
|
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'R')
|
dst, sp := beginMessage(dst, 'R')
|
||||||
dst = pgio.AppendInt32(dst, 8)
|
|
||||||
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
||||||
return dst
|
return finishMessage(dst, sp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'R')
|
dst, sp := beginMessage(dst, 'R')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
dst = pgio.AppendUint32(dst, AuthTypeSASL)
|
dst = pgio.AppendUint32(dst, AuthTypeSASL)
|
||||||
|
|
||||||
for _, s := range src.AuthMechanisms {
|
for _, s := range src.AuthMechanisms {
|
||||||
@@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
|||||||
}
|
}
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
|
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'R')
|
dst, sp := beginMessage(dst, 'R')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
|
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
|
||||||
|
|
||||||
dst = append(dst, src.Data...)
|
dst = append(dst, src.Data...)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
|
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'R')
|
dst, sp := beginMessage(dst, 'R')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
|
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
|
||||||
|
|
||||||
dst = append(dst, src.Data...)
|
dst = append(dst, src.Data...)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Unmarshaler.
|
// MarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
|||||||
+21
-4
@@ -16,7 +16,8 @@ type Backend struct {
|
|||||||
// before it is actually transmitted (i.e. before Flush).
|
// before it is actually transmitted (i.e. before Flush).
|
||||||
tracer *tracer
|
tracer *tracer
|
||||||
|
|
||||||
wbuf []byte
|
wbuf []byte
|
||||||
|
encodeError error
|
||||||
|
|
||||||
// Frontend message flyweights
|
// Frontend message flyweights
|
||||||
bind Bind
|
bind Bind
|
||||||
@@ -55,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
|
|||||||
return &Backend{cr: cr, w: w}
|
return &Backend{cr: cr, w: w}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
|
// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
|
||||||
// called.
|
// encountered will be returned from Flush.
|
||||||
func (b *Backend) Send(msg BackendMessage) {
|
func (b *Backend) Send(msg BackendMessage) {
|
||||||
|
if b.encodeError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
prevLen := len(b.wbuf)
|
prevLen := len(b.wbuf)
|
||||||
b.wbuf = msg.Encode(b.wbuf)
|
newBuf, err := msg.Encode(b.wbuf)
|
||||||
|
if err != nil {
|
||||||
|
b.encodeError = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.wbuf = newBuf
|
||||||
|
|
||||||
if b.tracer != nil {
|
if b.tracer != nil {
|
||||||
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
|
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
|
||||||
}
|
}
|
||||||
@@ -67,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) {
|
|||||||
|
|
||||||
// Flush writes any pending messages to the frontend (i.e. the client).
|
// Flush writes any pending messages to the frontend (i.e. the client).
|
||||||
func (b *Backend) Flush() error {
|
func (b *Backend) Flush() error {
|
||||||
|
if err := b.encodeError; err != nil {
|
||||||
|
b.encodeError = nil
|
||||||
|
b.wbuf = b.wbuf[:0]
|
||||||
|
return &writeError{err: err, safeToRetry: true}
|
||||||
|
}
|
||||||
|
|
||||||
n, err := b.w.Write(b.wbuf)
|
n, err := b.w.Write(b.wbuf)
|
||||||
|
|
||||||
const maxLen = 1024
|
const maxLen = 1024
|
||||||
|
|||||||
@@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'K')
|
dst, sp := beginMessage(dst, 'K')
|
||||||
dst = pgio.AppendUint32(dst, 12)
|
|
||||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||||
return dst
|
return finishMessage(dst, sp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
|
|||||||
"username": "tester",
|
"username": "tester",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
dst := []byte{}
|
dst, err := want.Encode([]byte{})
|
||||||
dst = want.Encode(dst)
|
require.NoError(t, err)
|
||||||
|
|
||||||
server := &interruptReader{}
|
server := &interruptReader{}
|
||||||
server.push(dst)
|
server.push(dst)
|
||||||
|
|||||||
+14
-7
@@ -5,7 +5,9 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
@@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *Bind) Encode(dst []byte) []byte {
|
func (src *Bind) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'B')
|
dst, sp := beginMessage(dst, 'B')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.DestinationPortal...)
|
dst = append(dst, src.DestinationPortal...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
dst = append(dst, src.PreparedStatement...)
|
dst = append(dst, src.PreparedStatement...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
if len(src.ParameterFormatCodes) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many parameter format codes")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
||||||
for _, fc := range src.ParameterFormatCodes {
|
for _, fc := range src.ParameterFormatCodes {
|
||||||
dst = pgio.AppendInt16(dst, fc)
|
dst = pgio.AppendInt16(dst, fc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(src.Parameters) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many parameters")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
||||||
for _, p := range src.Parameters {
|
for _, p := range src.Parameters {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
@@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte {
|
|||||||
dst = append(dst, p...)
|
dst = append(dst, p...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(src.ResultFormatCodes) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many result format codes")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
||||||
for _, fc := range src.ResultFormatCodes {
|
for _, fc := range src.ResultFormatCodes {
|
||||||
dst = pgio.AppendInt16(dst, fc)
|
dst = pgio.AppendInt16(dst, fc)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *BindComplete) Encode(dst []byte) []byte {
|
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, '2', 0, 0, 0, 4)
|
return append(dst, '2', 0, 0, 0, 4), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package pgproto3_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/andoma-go/pgx/v5/pgproto3"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Maximum allowed size.
|
||||||
|
_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 1 byte too big
|
||||||
|
_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
@@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
func (src *CancelRequest) Encode(dst []byte) []byte {
|
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = pgio.AppendInt32(dst, 16)
|
dst = pgio.AppendInt32(dst, 16)
|
||||||
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||||
return dst
|
return dst, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+3
-11
@@ -4,8 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Close struct {
|
type Close struct {
|
||||||
@@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *Close) Encode(dst []byte) []byte {
|
func (src *Close) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'C')
|
dst, sp := beginMessage(dst, 'C')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.ObjectType)
|
dst = append(dst, src.ObjectType)
|
||||||
dst = append(dst, src.Name...)
|
dst = append(dst, src.Name...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *CloseComplete) Encode(dst []byte) []byte {
|
func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, '3', 0, 0, 0, 4)
|
return append(dst, '3', 0, 0, 0, 4), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package pgproto3
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type CommandComplete struct {
|
type CommandComplete struct {
|
||||||
@@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *CommandComplete) Encode(dst []byte) []byte {
|
func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'C')
|
dst, sp := beginMessage(dst, 'C')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.CommandTag...)
|
dst = append(dst, src.CommandTag...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
@@ -44,19 +45,18 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'W')
|
dst, sp := beginMessage(dst, 'W')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
dst = append(dst, src.OverallFormat)
|
dst = append(dst, src.OverallFormat)
|
||||||
|
if len(src.ColumnFormatCodes) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many column format codes")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||||
for _, fc := range src.ColumnFormatCodes {
|
for _, fc := range src.ColumnFormatCodes {
|
||||||
dst = pgio.AppendUint16(dst, fc)
|
dst = pgio.AppendUint16(dst, fc)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/pgproto3"
|
"github.com/andoma-go/pgx/v5/pgproto3"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEncodeDecode(t *testing.T) {
|
func TestEncodeDecode(t *testing.T) {
|
||||||
@@ -13,6 +14,7 @@ func TestEncodeDecode(t *testing.T) {
|
|||||||
err := dstResp.Decode(srcBytes[5:])
|
err := dstResp.Decode(srcBytes[5:])
|
||||||
assert.NoError(t, err, "No errors on decode")
|
assert.NoError(t, err, "No errors on decode")
|
||||||
dstBytes := []byte{}
|
dstBytes := []byte{}
|
||||||
dstBytes = dstResp.Encode(dstBytes)
|
dstBytes, err = dstResp.Encode(dstBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
|
assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package pgproto3
|
|||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type CopyData struct {
|
type CopyData struct {
|
||||||
@@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *CopyData) Encode(dst []byte) []byte {
|
func (src *CopyData) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'd')
|
dst, sp := beginMessage(dst, 'd')
|
||||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
|
||||||
dst = append(dst, src.Data...)
|
dst = append(dst, src.Data...)
|
||||||
return dst
|
return finishMessage(dst, sp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *CopyDone) Encode(dst []byte) []byte {
|
func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, 'c', 0, 0, 0, 4)
|
return append(dst, 'c', 0, 0, 0, 4), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+3
-11
@@ -3,8 +3,6 @@ package pgproto3
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type CopyFail struct {
|
type CopyFail struct {
|
||||||
@@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *CopyFail) Encode(dst []byte) []byte {
|
func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'f')
|
dst, sp := beginMessage(dst, 'f')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.Message...)
|
dst = append(dst, src.Message...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
@@ -44,20 +45,19 @@ func (dst *CopyInResponse) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *CopyInResponse) Encode(dst []byte) []byte {
|
func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'G')
|
dst, sp := beginMessage(dst, 'G')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.OverallFormat)
|
dst = append(dst, src.OverallFormat)
|
||||||
|
if len(src.ColumnFormatCodes) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many column format codes")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||||
for _, fc := range src.ColumnFormatCodes {
|
for _, fc := range src.ColumnFormatCodes {
|
||||||
dst = pgio.AppendUint16(dst, fc)
|
dst = pgio.AppendUint16(dst, fc)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
@@ -43,21 +44,20 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *CopyOutResponse) Encode(dst []byte) []byte {
|
func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'H')
|
dst, sp := beginMessage(dst, 'H')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.OverallFormat)
|
dst = append(dst, src.OverallFormat)
|
||||||
|
|
||||||
|
if len(src.ColumnFormatCodes) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many column format codes")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||||
for _, fc := range src.ColumnFormatCodes {
|
for _, fc := range src.ColumnFormatCodes {
|
||||||
dst = pgio.AppendUint16(dst, fc)
|
dst = pgio.AppendUint16(dst, fc)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
@@ -63,11 +65,12 @@ func (dst *DataRow) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *DataRow) Encode(dst []byte) []byte {
|
func (src *DataRow) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'D')
|
dst, sp := beginMessage(dst, 'D')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
|
if len(src.Values) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many values")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
|
||||||
for _, v := range src.Values {
|
for _, v := range src.Values {
|
||||||
if v == nil {
|
if v == nil {
|
||||||
@@ -79,9 +82,7 @@ func (src *DataRow) Encode(dst []byte) []byte {
|
|||||||
dst = append(dst, v...)
|
dst = append(dst, v...)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+3
-11
@@ -4,8 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Describe struct {
|
type Describe struct {
|
||||||
@@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *Describe) Encode(dst []byte) []byte {
|
func (src *Describe) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'D')
|
dst, sp := beginMessage(dst, 'D')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.ObjectType)
|
dst = append(dst, src.ObjectType)
|
||||||
dst = append(dst, src.Name...)
|
dst = append(dst, src.Name...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
|
func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, 'I', 0, 0, 0, 4)
|
return append(dst, 'I', 0, 0, 0, 4), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+64
-71
@@ -2,7 +2,6 @@ package pgproto3
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
@@ -111,119 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *ErrorResponse) Encode(dst []byte) []byte {
|
func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, src.marshalBinary('E')...)
|
dst, sp := beginMessage(dst, 'E')
|
||||||
|
dst = src.appendFields(dst)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
|
func (src *ErrorResponse) appendFields(dst []byte) []byte {
|
||||||
var bigEndian BigEndianBuf
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
|
|
||||||
buf.WriteByte(typeByte)
|
|
||||||
buf.Write(bigEndian.Uint32(0))
|
|
||||||
|
|
||||||
if src.Severity != "" {
|
if src.Severity != "" {
|
||||||
buf.WriteByte('S')
|
dst = append(dst, 'S')
|
||||||
buf.WriteString(src.Severity)
|
dst = append(dst, src.Severity...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.SeverityUnlocalized != "" {
|
if src.SeverityUnlocalized != "" {
|
||||||
buf.WriteByte('V')
|
dst = append(dst, 'V')
|
||||||
buf.WriteString(src.SeverityUnlocalized)
|
dst = append(dst, src.SeverityUnlocalized...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.Code != "" {
|
if src.Code != "" {
|
||||||
buf.WriteByte('C')
|
dst = append(dst, 'C')
|
||||||
buf.WriteString(src.Code)
|
dst = append(dst, src.Code...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.Message != "" {
|
if src.Message != "" {
|
||||||
buf.WriteByte('M')
|
dst = append(dst, 'M')
|
||||||
buf.WriteString(src.Message)
|
dst = append(dst, src.Message...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.Detail != "" {
|
if src.Detail != "" {
|
||||||
buf.WriteByte('D')
|
dst = append(dst, 'D')
|
||||||
buf.WriteString(src.Detail)
|
dst = append(dst, src.Detail...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.Hint != "" {
|
if src.Hint != "" {
|
||||||
buf.WriteByte('H')
|
dst = append(dst, 'H')
|
||||||
buf.WriteString(src.Hint)
|
dst = append(dst, src.Hint...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.Position != 0 {
|
if src.Position != 0 {
|
||||||
buf.WriteByte('P')
|
dst = append(dst, 'P')
|
||||||
buf.WriteString(strconv.Itoa(int(src.Position)))
|
dst = append(dst, strconv.Itoa(int(src.Position))...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.InternalPosition != 0 {
|
if src.InternalPosition != 0 {
|
||||||
buf.WriteByte('p')
|
dst = append(dst, 'p')
|
||||||
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
|
dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.InternalQuery != "" {
|
if src.InternalQuery != "" {
|
||||||
buf.WriteByte('q')
|
dst = append(dst, 'q')
|
||||||
buf.WriteString(src.InternalQuery)
|
dst = append(dst, src.InternalQuery...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.Where != "" {
|
if src.Where != "" {
|
||||||
buf.WriteByte('W')
|
dst = append(dst, 'W')
|
||||||
buf.WriteString(src.Where)
|
dst = append(dst, src.Where...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.SchemaName != "" {
|
if src.SchemaName != "" {
|
||||||
buf.WriteByte('s')
|
dst = append(dst, 's')
|
||||||
buf.WriteString(src.SchemaName)
|
dst = append(dst, src.SchemaName...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.TableName != "" {
|
if src.TableName != "" {
|
||||||
buf.WriteByte('t')
|
dst = append(dst, 't')
|
||||||
buf.WriteString(src.TableName)
|
dst = append(dst, src.TableName...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.ColumnName != "" {
|
if src.ColumnName != "" {
|
||||||
buf.WriteByte('c')
|
dst = append(dst, 'c')
|
||||||
buf.WriteString(src.ColumnName)
|
dst = append(dst, src.ColumnName...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.DataTypeName != "" {
|
if src.DataTypeName != "" {
|
||||||
buf.WriteByte('d')
|
dst = append(dst, 'd')
|
||||||
buf.WriteString(src.DataTypeName)
|
dst = append(dst, src.DataTypeName...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.ConstraintName != "" {
|
if src.ConstraintName != "" {
|
||||||
buf.WriteByte('n')
|
dst = append(dst, 'n')
|
||||||
buf.WriteString(src.ConstraintName)
|
dst = append(dst, src.ConstraintName...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.File != "" {
|
if src.File != "" {
|
||||||
buf.WriteByte('F')
|
dst = append(dst, 'F')
|
||||||
buf.WriteString(src.File)
|
dst = append(dst, src.File...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.Line != 0 {
|
if src.Line != 0 {
|
||||||
buf.WriteByte('L')
|
dst = append(dst, 'L')
|
||||||
buf.WriteString(strconv.Itoa(int(src.Line)))
|
dst = append(dst, strconv.Itoa(int(src.Line))...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
if src.Routine != "" {
|
if src.Routine != "" {
|
||||||
buf.WriteByte('R')
|
dst = append(dst, 'R')
|
||||||
buf.WriteString(src.Routine)
|
dst = append(dst, src.Routine...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range src.UnknownFields {
|
for k, v := range src.UnknownFields {
|
||||||
buf.WriteByte(k)
|
dst = append(dst, k)
|
||||||
buf.WriteString(v)
|
dst = append(dst, v...)
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf.WriteByte(0)
|
dst = append(dst, 0)
|
||||||
|
|
||||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
return dst
|
||||||
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func (p *PgFortuneBackend) Run() error {
|
|||||||
return fmt.Errorf("error generating query response: %w", err)
|
return fmt.Errorf("error generating query response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
||||||
{
|
{
|
||||||
Name: []byte("fortune"),
|
Name: []byte("fortune"),
|
||||||
TableOID: 0,
|
TableOID: 0,
|
||||||
@@ -56,10 +56,10 @@ func (p *PgFortuneBackend) Run() error {
|
|||||||
TypeModifier: -1,
|
TypeModifier: -1,
|
||||||
Format: 0,
|
Format: 0,
|
||||||
},
|
},
|
||||||
}}).Encode(nil)
|
}}).Encode(nil))
|
||||||
buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)
|
buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf))
|
||||||
buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
|
buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf))
|
||||||
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
|
buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
|
||||||
_, err = p.conn.Write(buf)
|
_, err = p.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error writing query response: %w", err)
|
return fmt.Errorf("error writing query response: %w", err)
|
||||||
@@ -80,8 +80,8 @@ func (p *PgFortuneBackend) handleStartup() error {
|
|||||||
|
|
||||||
switch startupMessage.(type) {
|
switch startupMessage.(type) {
|
||||||
case *pgproto3.StartupMessage:
|
case *pgproto3.StartupMessage:
|
||||||
buf := (&pgproto3.AuthenticationOk{}).Encode(nil)
|
buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))
|
||||||
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
|
buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
|
||||||
_, err = p.conn.Write(buf)
|
_, err = p.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error sending ready for query: %w", err)
|
return fmt.Errorf("error sending ready for query: %w", err)
|
||||||
@@ -102,3 +102,10 @@ func (p *PgFortuneBackend) handleStartup() error {
|
|||||||
func (p *PgFortuneBackend) Close() error {
|
func (p *PgFortuneBackend) Close() error {
|
||||||
return p.conn.Close()
|
return p.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustEncode(buf []byte, err error) []byte {
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|||||||
+3
-10
@@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *Execute) Encode(dst []byte) []byte {
|
func (src *Execute) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'E')
|
dst, sp := beginMessage(dst, 'E')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.Portal...)
|
dst = append(dst, src.Portal...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
|
||||||
dst = pgio.AppendUint32(dst, src.MaxRows)
|
dst = pgio.AppendUint32(dst, src.MaxRows)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+2
-2
@@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *Flush) Encode(dst []byte) []byte {
|
func (src *Flush) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, 'H', 0, 0, 0, 4)
|
return append(dst, 'H', 0, 0, 0, 4), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+112
-25
@@ -18,7 +18,8 @@ type Frontend struct {
|
|||||||
// idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
|
// idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
|
||||||
tracer *tracer
|
tracer *tracer
|
||||||
|
|
||||||
wbuf []byte
|
wbuf []byte
|
||||||
|
encodeError error
|
||||||
|
|
||||||
// Backend message flyweights
|
// Backend message flyweights
|
||||||
authenticationOk AuthenticationOk
|
authenticationOk AuthenticationOk
|
||||||
@@ -64,16 +65,26 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend {
|
|||||||
return &Frontend{cr: cr, w: w}
|
return &Frontend{cr: cr, w: w}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
|
// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error
|
||||||
// called.
|
// encountered will be returned from Flush.
|
||||||
//
|
//
|
||||||
// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods
|
// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods
|
||||||
// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an
|
// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an
|
||||||
// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden
|
// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden
|
||||||
// behind an interface.
|
// behind an interface.
|
||||||
func (f *Frontend) Send(msg FrontendMessage) {
|
func (f *Frontend) Send(msg FrontendMessage) {
|
||||||
|
if f.encodeError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
prevLen := len(f.wbuf)
|
prevLen := len(f.wbuf)
|
||||||
f.wbuf = msg.Encode(f.wbuf)
|
newBuf, err := msg.Encode(f.wbuf)
|
||||||
|
if err != nil {
|
||||||
|
f.encodeError = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.wbuf = newBuf
|
||||||
|
|
||||||
if f.tracer != nil {
|
if f.tracer != nil {
|
||||||
f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
|
f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
|
||||||
}
|
}
|
||||||
@@ -81,6 +92,12 @@ func (f *Frontend) Send(msg FrontendMessage) {
|
|||||||
|
|
||||||
// Flush writes any pending messages to the backend (i.e. the server).
|
// Flush writes any pending messages to the backend (i.e. the server).
|
||||||
func (f *Frontend) Flush() error {
|
func (f *Frontend) Flush() error {
|
||||||
|
if err := f.encodeError; err != nil {
|
||||||
|
f.encodeError = nil
|
||||||
|
f.wbuf = f.wbuf[:0]
|
||||||
|
return &writeError{err: err, safeToRetry: true}
|
||||||
|
}
|
||||||
|
|
||||||
if len(f.wbuf) == 0 {
|
if len(f.wbuf) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -116,71 +133,141 @@ func (f *Frontend) Untrace() {
|
|||||||
f.tracer = nil
|
f.tracer = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until
|
// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||||
// Flush is called.
|
// error encountered will be returned from Flush.
|
||||||
func (f *Frontend) SendBind(msg *Bind) {
|
func (f *Frontend) SendBind(msg *Bind) {
|
||||||
|
if f.encodeError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
prevLen := len(f.wbuf)
|
prevLen := len(f.wbuf)
|
||||||
f.wbuf = msg.Encode(f.wbuf)
|
newBuf, err := msg.Encode(f.wbuf)
|
||||||
|
if err != nil {
|
||||||
|
f.encodeError = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.wbuf = newBuf
|
||||||
|
|
||||||
if f.tracer != nil {
|
if f.tracer != nil {
|
||||||
f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
|
f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until
|
// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||||
// Flush is called.
|
// error encountered will be returned from Flush.
|
||||||
func (f *Frontend) SendParse(msg *Parse) {
|
func (f *Frontend) SendParse(msg *Parse) {
|
||||||
|
if f.encodeError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
prevLen := len(f.wbuf)
|
prevLen := len(f.wbuf)
|
||||||
f.wbuf = msg.Encode(f.wbuf)
|
newBuf, err := msg.Encode(f.wbuf)
|
||||||
|
if err != nil {
|
||||||
|
f.encodeError = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.wbuf = newBuf
|
||||||
|
|
||||||
if f.tracer != nil {
|
if f.tracer != nil {
|
||||||
f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
|
f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until
|
// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||||
// Flush is called.
|
// error encountered will be returned from Flush.
|
||||||
func (f *Frontend) SendClose(msg *Close) {
|
func (f *Frontend) SendClose(msg *Close) {
|
||||||
|
if f.encodeError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
prevLen := len(f.wbuf)
|
prevLen := len(f.wbuf)
|
||||||
f.wbuf = msg.Encode(f.wbuf)
|
newBuf, err := msg.Encode(f.wbuf)
|
||||||
|
if err != nil {
|
||||||
|
f.encodeError = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.wbuf = newBuf
|
||||||
|
|
||||||
if f.tracer != nil {
|
if f.tracer != nil {
|
||||||
f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
|
f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until
|
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is
|
||||||
// Flush is called.
|
// called. Any error encountered will be returned from Flush.
|
||||||
func (f *Frontend) SendDescribe(msg *Describe) {
|
func (f *Frontend) SendDescribe(msg *Describe) {
|
||||||
|
if f.encodeError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
prevLen := len(f.wbuf)
|
prevLen := len(f.wbuf)
|
||||||
f.wbuf = msg.Encode(f.wbuf)
|
newBuf, err := msg.Encode(f.wbuf)
|
||||||
|
if err != nil {
|
||||||
|
f.encodeError = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.wbuf = newBuf
|
||||||
|
|
||||||
if f.tracer != nil {
|
if f.tracer != nil {
|
||||||
f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
|
f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendExecute sends an Execute message to the backend (i.e. the server). The message is not guaranteed to be written until
|
// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called.
|
||||||
// Flush is called.
|
// Any error encountered will be returned from Flush.
|
||||||
func (f *Frontend) SendExecute(msg *Execute) {
|
func (f *Frontend) SendExecute(msg *Execute) {
|
||||||
|
if f.encodeError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
prevLen := len(f.wbuf)
|
prevLen := len(f.wbuf)
|
||||||
f.wbuf = msg.Encode(f.wbuf)
|
newBuf, err := msg.Encode(f.wbuf)
|
||||||
|
if err != nil {
|
||||||
|
f.encodeError = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.wbuf = newBuf
|
||||||
|
|
||||||
if f.tracer != nil {
|
if f.tracer != nil {
|
||||||
f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
|
f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until
|
// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||||
// Flush is called.
|
// error encountered will be returned from Flush.
|
||||||
func (f *Frontend) SendSync(msg *Sync) {
|
func (f *Frontend) SendSync(msg *Sync) {
|
||||||
|
if f.encodeError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
prevLen := len(f.wbuf)
|
prevLen := len(f.wbuf)
|
||||||
f.wbuf = msg.Encode(f.wbuf)
|
newBuf, err := msg.Encode(f.wbuf)
|
||||||
|
if err != nil {
|
||||||
|
f.encodeError = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.wbuf = newBuf
|
||||||
|
|
||||||
if f.tracer != nil {
|
if f.tracer != nil {
|
||||||
f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
|
f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until
|
// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||||
// Flush is called.
|
// error encountered will be returned from Flush.
|
||||||
func (f *Frontend) SendQuery(msg *Query) {
|
func (f *Frontend) SendQuery(msg *Query) {
|
||||||
|
if f.encodeError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
prevLen := len(f.wbuf)
|
prevLen := len(f.wbuf)
|
||||||
f.wbuf = msg.Encode(f.wbuf)
|
newBuf, err := msg.Encode(f.wbuf)
|
||||||
|
if err != nil {
|
||||||
|
f.encodeError = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.wbuf = newBuf
|
||||||
|
|
||||||
if f.tracer != nil {
|
if f.tracer != nil {
|
||||||
f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
|
f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package pgproto3
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
@@ -71,15 +73,21 @@ func (dst *FunctionCall) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *FunctionCall) Encode(dst []byte) []byte {
|
func (src *FunctionCall) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'F')
|
dst, sp := beginMessage(dst, 'F')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
|
|
||||||
dst = pgio.AppendUint32(dst, src.Function)
|
dst = pgio.AppendUint32(dst, src.Function)
|
||||||
|
|
||||||
|
if len(src.ArgFormatCodes) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many arg format codes")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
|
||||||
for _, argFormatCode := range src.ArgFormatCodes {
|
for _, argFormatCode := range src.ArgFormatCodes {
|
||||||
dst = pgio.AppendUint16(dst, argFormatCode)
|
dst = pgio.AppendUint16(dst, argFormatCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(src.Arguments) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many arguments")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
|
||||||
for _, argument := range src.Arguments {
|
for _, argument := range src.Arguments {
|
||||||
if argument == nil {
|
if argument == nil {
|
||||||
@@ -90,6 +98,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
|
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'V')
|
dst, sp := beginMessage(dst, 'V')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
if src.Result == nil {
|
if src.Result == nil {
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
@@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
|||||||
dst = append(dst, src.Result...)
|
dst = append(dst, src.Result...)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFunctionCall_EncodeDecode(t *testing.T) {
|
func TestFunctionCall_EncodeDecode(t *testing.T) {
|
||||||
@@ -30,7 +32,8 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
|
|||||||
Arguments: tt.fields.Arguments,
|
Arguments: tt.fields.Arguments,
|
||||||
ResultFormatCode: tt.fields.ResultFormatCode,
|
ResultFormatCode: tt.fields.ResultFormatCode,
|
||||||
}
|
}
|
||||||
encoded := src.Encode([]byte{})
|
encoded, err := src.Encode([]byte{})
|
||||||
|
require.NoError(t, err)
|
||||||
dst := &FunctionCall{}
|
dst := &FunctionCall{}
|
||||||
// Check the header
|
// Check the header
|
||||||
msgTypeCode := encoded[0]
|
msgTypeCode := encoded[0]
|
||||||
@@ -44,7 +47,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
|
|||||||
t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded))
|
t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded))
|
||||||
}
|
}
|
||||||
// Check decoding works as expected
|
// Check decoding works as expected
|
||||||
err := dst.Decode(encoded[5:])
|
err = dst.Decode(encoded[5:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !tt.wantErr {
|
if !tt.wantErr {
|
||||||
t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|||||||
@@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
func (src *GSSEncRequest) Encode(dst []byte) []byte {
|
func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = pgio.AppendInt32(dst, 8)
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
dst = pgio.AppendInt32(dst, gssEncReqNumber)
|
dst = pgio.AppendInt32(dst, gssEncReqNumber)
|
||||||
return dst
|
return dst, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ package pgproto3
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type GSSResponse struct {
|
type GSSResponse struct {
|
||||||
@@ -18,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GSSResponse) Encode(dst []byte) []byte {
|
func (g *GSSResponse) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'p')
|
dst, sp := beginMessage(dst, 'p')
|
||||||
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
|
|
||||||
dst = append(dst, g.Data...)
|
dst = append(dst, g.Data...)
|
||||||
return dst
|
return finishMessage(dst, sp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+2
-2
@@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *NoData) Encode(dst []byte) []byte {
|
func (src *NoData) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, 'n', 0, 0, 0, 4)
|
return append(dst, 'n', 0, 0, 0, 4), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *NoticeResponse) Encode(dst []byte) []byte {
|
func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
|
dst, sp := beginMessage(dst, 'N')
|
||||||
|
dst = (*ErrorResponse)(src).appendFields(dst)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,20 +45,14 @@ func (dst *NotificationResponse) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *NotificationResponse) Encode(dst []byte) []byte {
|
func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'A')
|
dst, sp := beginMessage(dst, 'A')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint32(dst, src.PID)
|
dst = pgio.AppendUint32(dst, src.PID)
|
||||||
dst = append(dst, src.Channel...)
|
dst = append(dst, src.Channel...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
dst = append(dst, src.Payload...)
|
dst = append(dst, src.Payload...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
@@ -39,19 +41,18 @@ func (dst *ParameterDescription) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *ParameterDescription) Encode(dst []byte) []byte {
|
func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 't')
|
dst, sp := beginMessage(dst, 't')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
|
if len(src.ParameterOIDs) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many parameter oids")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||||
for _, oid := range src.ParameterOIDs {
|
for _, oid := range src.ParameterOIDs {
|
||||||
dst = pgio.AppendUint32(dst, oid)
|
dst = pgio.AppendUint32(dst, oid)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package pgproto3
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ParameterStatus struct {
|
type ParameterStatus struct {
|
||||||
@@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *ParameterStatus) Encode(dst []byte) []byte {
|
func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'S')
|
dst, sp := beginMessage(dst, 'S')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.Name...)
|
dst = append(dst, src.Name...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
dst = append(dst, src.Value...)
|
dst = append(dst, src.Value...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+8
-7
@@ -4,6 +4,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
@@ -52,24 +54,23 @@ func (dst *Parse) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *Parse) Encode(dst []byte) []byte {
|
func (src *Parse) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'P')
|
dst, sp := beginMessage(dst, 'P')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.Name...)
|
dst = append(dst, src.Name...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
dst = append(dst, src.Query...)
|
dst = append(dst, src.Query...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
if len(src.ParameterOIDs) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many parameter oids")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||||
for _, oid := range src.ParameterOIDs {
|
for _, oid := range src.ParameterOIDs {
|
||||||
dst = pgio.AppendUint32(dst, oid)
|
dst = pgio.AppendUint32(dst, oid)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *ParseComplete) Encode(dst []byte) []byte {
|
func (src *ParseComplete) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, '1', 0, 0, 0, 4)
|
return append(dst, '1', 0, 0, 0, 4), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package pgproto3
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type PasswordMessage struct {
|
type PasswordMessage struct {
|
||||||
@@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *PasswordMessage) Encode(dst []byte) []byte {
|
func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'p')
|
dst, sp := beginMessage(dst, 'p')
|
||||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
|
|
||||||
|
|
||||||
dst = append(dst, src.Password...)
|
dst = append(dst, src.Password...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+27
-1
@@ -4,8 +4,14 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL
|
||||||
|
// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff.
|
||||||
|
const maxMessageBodyLen = (0x3fffffff - 1)
|
||||||
|
|
||||||
// Message is the interface implemented by an object that can decode and encode
|
// Message is the interface implemented by an object that can decode and encode
|
||||||
// a particular PostgreSQL message.
|
// a particular PostgreSQL message.
|
||||||
type Message interface {
|
type Message interface {
|
||||||
@@ -14,7 +20,7 @@ type Message interface {
|
|||||||
Decode(data []byte) error
|
Decode(data []byte) error
|
||||||
|
|
||||||
// Encode appends itself to dst and returns the new buffer.
|
// Encode appends itself to dst and returns the new buffer.
|
||||||
Encode(dst []byte) []byte
|
Encode(dst []byte) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FrontendMessage is a message sent by the frontend (i.e. the client).
|
// FrontendMessage is a message sent by the frontend (i.e. the client).
|
||||||
@@ -92,3 +98,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
return nil, errors.New("unknown protocol representation")
|
return nil, errors.New("unknown protocol representation")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to
|
||||||
|
// dst. It returns the new buffer and the position of the message length placeholder.
|
||||||
|
func beginMessage(dst []byte, t byte) ([]byte, int) {
|
||||||
|
dst = append(dst, t)
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
return dst, sp
|
||||||
|
}
|
||||||
|
|
||||||
|
// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to
|
||||||
|
// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer.
|
||||||
|
func finishMessage(dst []byte, sp int) ([]byte, error) {
|
||||||
|
messageBodyLen := len(dst[sp:])
|
||||||
|
if messageBodyLen > maxMessageBodyLen {
|
||||||
|
return nil, errors.New("message body too large")
|
||||||
|
}
|
||||||
|
pgio.SetInt32(dst[sp:], int32(messageBodyLen))
|
||||||
|
return dst, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
const MaxMessageBodyLen = maxMessageBodyLen
|
||||||
@@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *PortalSuspended) Encode(dst []byte) []byte {
|
func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, 's', 0, 0, 0, 4)
|
return append(dst, 's', 0, 0, 0, 4), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+3
-8
@@ -3,8 +3,6 @@ package pgproto3
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Query struct {
|
type Query struct {
|
||||||
@@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *Query) Encode(dst []byte) []byte {
|
func (src *Query) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'Q')
|
dst, sp := beginMessage(dst, 'Q')
|
||||||
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
|
|
||||||
|
|
||||||
dst = append(dst, src.String...)
|
dst = append(dst, src.String...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package pgproto3_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/andoma-go/pgx/v5/pgproto3"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQueryBiggerThanMaxMessageBodyLen(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Maximum allowed size. 4 bytes for size and 1 byte for 0 terminated string.
|
||||||
|
_, err := (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-5))}).Encode(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 1 byte too big
|
||||||
|
_, err = (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-4))}).Encode(nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
@@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *ReadyForQuery) Encode(dst []byte) []byte {
|
func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
|
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
@@ -99,11 +101,12 @@ func (dst *RowDescription) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *RowDescription) Encode(dst []byte) []byte {
|
func (src *RowDescription) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'T')
|
dst, sp := beginMessage(dst, 'T')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
|
if len(src.Fields) > math.MaxUint16 {
|
||||||
|
return nil, errors.New("too many fields")
|
||||||
|
}
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
|
||||||
for _, fd := range src.Fields {
|
for _, fd := range src.Fields {
|
||||||
dst = append(dst, fd.Name...)
|
dst = append(dst, fd.Name...)
|
||||||
@@ -117,9 +120,7 @@ func (src *RowDescription) Encode(dst []byte) []byte {
|
|||||||
dst = pgio.AppendInt16(dst, fd.Format)
|
dst = pgio.AppendInt16(dst, fd.Format)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -39,10 +39,8 @@ func (dst *SASLInitialResponse) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *SASLInitialResponse) Encode(dst []byte) []byte {
|
func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'p')
|
dst, sp := beginMessage(dst, 'p')
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, []byte(src.AuthMechanism)...)
|
dst = append(dst, []byte(src.AuthMechanism)...)
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
@@ -50,9 +48,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte {
|
|||||||
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
|
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
|
||||||
dst = append(dst, src.Data...)
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package pgproto3
|
|||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type SASLResponse struct {
|
type SASLResponse struct {
|
||||||
@@ -22,13 +20,10 @@ func (dst *SASLResponse) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *SASLResponse) Encode(dst []byte) []byte {
|
func (src *SASLResponse) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = append(dst, 'p')
|
dst, sp := beginMessage(dst, 'p')
|
||||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
|
||||||
|
|
||||||
dst = append(dst, src.Data...)
|
dst = append(dst, src.Data...)
|
||||||
|
return finishMessage(dst, sp)
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
func (src *SSLRequest) Encode(dst []byte) []byte {
|
func (src *SSLRequest) Encode(dst []byte) ([]byte, error) {
|
||||||
dst = pgio.AppendInt32(dst, 8)
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
dst = pgio.AppendInt32(dst, sslRequestNumber)
|
dst = pgio.AppendInt32(dst, sslRequestNumber)
|
||||||
return dst
|
return dst, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *StartupMessage) Encode(dst []byte) []byte {
|
func (src *StartupMessage) Encode(dst []byte) ([]byte, error) {
|
||||||
sp := len(dst)
|
sp := len(dst)
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
@@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte {
|
|||||||
}
|
}
|
||||||
dst = append(dst, 0)
|
dst = append(dst, 0)
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
return finishMessage(dst, sp)
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+2
-2
@@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *Sync) Encode(dst []byte) []byte {
|
func (src *Sync) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, 'S', 0, 0, 0, 4)
|
return append(dst, 'S', 0, 0, 0, 4), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
func (src *Terminate) Encode(dst []byte) []byte {
|
func (src *Terminate) Encode(dst []byte) ([]byte, error) {
|
||||||
return append(dst, 'X', 0, 0, 0, 4)
|
return append(dst, 'X', 0, 0, 0, 4), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
|||||||
+3
-1
@@ -176,8 +176,10 @@ func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error {
|
|||||||
|
|
||||||
bitLen := int32(binary.BigEndian.Uint32(src))
|
bitLen := int32(binary.BigEndian.Uint32(src))
|
||||||
rp := 4
|
rp := 4
|
||||||
|
buf := make([]byte, len(src[rp:]))
|
||||||
|
copy(buf, src[rp:])
|
||||||
|
|
||||||
return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true})
|
return scanner.ScanBits(Bits{Bytes: buf, Len: bitLen, Valid: true})
|
||||||
}
|
}
|
||||||
|
|
||||||
type scanPlanTextAnyToBitsScanner struct{}
|
type scanPlanTextAnyToBitsScanner struct{}
|
||||||
|
|||||||
+2
-2
@@ -297,12 +297,12 @@ func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var n float64
|
var n float32
|
||||||
err := codecScan(c, m, oid, format, src, &n)
|
err := codecScan(c, m, oid, format, src, &n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return n, nil
|
return float64(n), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||||
|
|||||||
@@ -25,6 +25,11 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
|
|||||||
case []byte:
|
case []byte:
|
||||||
return encodePlanJSONCodecEitherFormatByteSlice{}
|
return encodePlanJSONCodecEitherFormatByteSlice{}
|
||||||
|
|
||||||
|
// Handle json.RawMessage specifically because if it is run through json.Marshal it may be mutated.
|
||||||
|
// e.g. `{"foo": "bar"}` -> `{"foo":"bar"}`.
|
||||||
|
case json.RawMessage:
|
||||||
|
return encodePlanJSONCodecEitherFormatJSONRawMessage{}
|
||||||
|
|
||||||
// Cannot rely on driver.Valuer being handled later because anything can be marshalled.
|
// Cannot rely on driver.Valuer being handled later because anything can be marshalled.
|
||||||
//
|
//
|
||||||
// https://github.com/jackc/pgx/issues/1430
|
// https://github.com/jackc/pgx/issues/1430
|
||||||
@@ -79,6 +84,18 @@ func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (n
|
|||||||
return buf, nil
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type encodePlanJSONCodecEitherFormatJSONRawMessage struct{}
|
||||||
|
|
||||||
|
func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
|
jsonBytes := value.(json.RawMessage)
|
||||||
|
if jsonBytes == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = append(buf, jsonBytes...)
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
type encodePlanJSONCodecEitherFormatMarshal struct{}
|
type encodePlanJSONCodecEitherFormatMarshal struct{}
|
||||||
|
|
||||||
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
|
|||||||
+122
@@ -0,0 +1,122 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LtreeCodec struct{}
|
||||||
|
|
||||||
|
func (l LtreeCodec) FormatSupported(format int16) bool {
|
||||||
|
return format == TextFormatCode || format == BinaryFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreferredFormat returns the preferred format.
|
||||||
|
func (l LtreeCodec) PreferredFormat() int16 {
|
||||||
|
return TextFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be
|
||||||
|
// found then nil is returned.
|
||||||
|
func (l LtreeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||||
|
switch format {
|
||||||
|
case TextFormatCode:
|
||||||
|
return (TextCodec)(l).PlanEncode(m, oid, format, value)
|
||||||
|
case BinaryFormatCode:
|
||||||
|
switch value.(type) {
|
||||||
|
case string:
|
||||||
|
return encodeLtreeCodecBinaryString{}
|
||||||
|
case []byte:
|
||||||
|
return encodeLtreeCodecBinaryByteSlice{}
|
||||||
|
case TextValuer:
|
||||||
|
return encodeLtreeCodecBinaryTextValuer{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type encodeLtreeCodecBinaryString struct{}
|
||||||
|
|
||||||
|
func (encodeLtreeCodecBinaryString) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
|
ltree := value.(string)
|
||||||
|
buf = append(buf, 1)
|
||||||
|
return append(buf, ltree...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type encodeLtreeCodecBinaryByteSlice struct{}
|
||||||
|
|
||||||
|
func (encodeLtreeCodecBinaryByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
|
ltree := value.([]byte)
|
||||||
|
buf = append(buf, 1)
|
||||||
|
return append(buf, ltree...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type encodeLtreeCodecBinaryTextValuer struct{}
|
||||||
|
|
||||||
|
func (encodeLtreeCodecBinaryTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
|
t, err := value.(TextValuer).TextValue()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !t.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = append(buf, 1)
|
||||||
|
return append(buf, t.String...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If
|
||||||
|
// no plan can be found then nil is returned.
|
||||||
|
func (l LtreeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
switch format {
|
||||||
|
case TextFormatCode:
|
||||||
|
return (TextCodec)(l).PlanScan(m, oid, format, target)
|
||||||
|
case BinaryFormatCode:
|
||||||
|
switch target.(type) {
|
||||||
|
case *string:
|
||||||
|
return scanPlanBinaryLtreeToString{}
|
||||||
|
case TextScanner:
|
||||||
|
return scanPlanBinaryLtreeToTextScanner{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type scanPlanBinaryLtreeToString struct{}
|
||||||
|
|
||||||
|
func (scanPlanBinaryLtreeToString) Scan(src []byte, target any) error {
|
||||||
|
version := src[0]
|
||||||
|
if version != 1 {
|
||||||
|
return fmt.Errorf("unsupported ltree version %d", version)
|
||||||
|
}
|
||||||
|
|
||||||
|
p := (target).(*string)
|
||||||
|
*p = string(src[1:])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type scanPlanBinaryLtreeToTextScanner struct{}
|
||||||
|
|
||||||
|
func (scanPlanBinaryLtreeToTextScanner) Scan(src []byte, target any) error {
|
||||||
|
version := src[0]
|
||||||
|
if version != 1 {
|
||||||
|
return fmt.Errorf("unsupported ltree version %d", version)
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := (target).(TextScanner)
|
||||||
|
return scanner.ScanText(Text{String: string(src[1:]), Valid: true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface.
|
||||||
|
func (l LtreeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||||
|
return (TextCodec)(l).DecodeDatabaseSQLValue(m, oid, format, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeValue returns src decoded into its default format.
|
||||||
|
func (l LtreeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||||
|
return (TextCodec)(l).DecodeValue(m, oid, format, src)
|
||||||
|
}
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
package pgtype_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/andoma-go/pgx/v5/pgtype"
|
||||||
|
"github.com/andoma-go/pgx/v5/pgxtest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLtreeCodec(t *testing.T) {
|
||||||
|
skipCockroachDB(t, "Server does not support type ltree")
|
||||||
|
|
||||||
|
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "ltree", []pgxtest.ValueRoundTripTest{
|
||||||
|
{
|
||||||
|
Param: "A.B.C",
|
||||||
|
Result: new(string),
|
||||||
|
Test: isExpectedEq("A.B.C"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Param: pgtype.Text{String: "", Valid: true},
|
||||||
|
Result: new(pgtype.Text),
|
||||||
|
Test: isExpectedEq(pgtype.Text{String: "", Valid: true}),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -48,4 +48,23 @@ func TestMacaddrCodec(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))},
|
{nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "macaddr8", []pgxtest.ValueRoundTripTest{
|
||||||
|
{
|
||||||
|
mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"),
|
||||||
|
new(net.HardwareAddr),
|
||||||
|
isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"01:23:45:67:89:ab:01:08",
|
||||||
|
new(net.HardwareAddr),
|
||||||
|
isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"),
|
||||||
|
new(string),
|
||||||
|
isExpectedEq("01:23:45:67:89:ab:01:08"),
|
||||||
|
},
|
||||||
|
{nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
+4
-1
@@ -41,6 +41,7 @@ const (
|
|||||||
CircleOID = 718
|
CircleOID = 718
|
||||||
CircleArrayOID = 719
|
CircleArrayOID = 719
|
||||||
UnknownOID = 705
|
UnknownOID = 705
|
||||||
|
Macaddr8OID = 774
|
||||||
MacaddrOID = 829
|
MacaddrOID = 829
|
||||||
InetOID = 869
|
InetOID = 869
|
||||||
BoolArrayOID = 1000
|
BoolArrayOID = 1000
|
||||||
@@ -81,6 +82,8 @@ const (
|
|||||||
IntervalOID = 1186
|
IntervalOID = 1186
|
||||||
IntervalArrayOID = 1187
|
IntervalArrayOID = 1187
|
||||||
NumericArrayOID = 1231
|
NumericArrayOID = 1231
|
||||||
|
TimetzOID = 1266
|
||||||
|
TimetzArrayOID = 1270
|
||||||
BitOID = 1560
|
BitOID = 1560
|
||||||
BitArrayOID = 1561
|
BitArrayOID = 1561
|
||||||
VarbitOID = 1562
|
VarbitOID = 1562
|
||||||
@@ -559,7 +562,7 @@ func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nex
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if nextDstType != nil && dstValue.Type() != nextDstType {
|
if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) {
|
||||||
return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true
|
return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package pgtype
|
package pgtype
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -69,6 +70,7 @@ func initDefaultMap() {
|
|||||||
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
|
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
|
||||||
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
|
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
|
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
|
||||||
|
defaultMap.RegisterType(&Type{Name: "macaddr8", OID: Macaddr8OID, Codec: MacaddrCodec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}})
|
defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}})
|
defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}})
|
defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}})
|
||||||
@@ -173,6 +175,7 @@ func initDefaultMap() {
|
|||||||
registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz")
|
registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz")
|
||||||
registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval")
|
registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval")
|
||||||
registerDefaultPgTypeVariants[string](defaultMap, "text")
|
registerDefaultPgTypeVariants[string](defaultMap, "text")
|
||||||
|
registerDefaultPgTypeVariants[json.RawMessage](defaultMap, "json")
|
||||||
registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea")
|
registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea")
|
||||||
|
|
||||||
registerDefaultPgTypeVariants[net.IP](defaultMap, "inet")
|
registerDefaultPgTypeVariants[net.IP](defaultMap, "inet")
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ func init() {
|
|||||||
// Test for renamed types
|
// Test for renamed types
|
||||||
type _string string
|
type _string string
|
||||||
type _bool bool
|
type _bool bool
|
||||||
|
type _uint8 uint8
|
||||||
type _int8 int8
|
type _int8 int8
|
||||||
type _int16 int16
|
type _int16 int16
|
||||||
type _int16Slice []int16
|
type _int16Slice []int16
|
||||||
@@ -453,6 +454,14 @@ func TestMapScanNullToWrongType(t *testing.T) {
|
|||||||
assert.False(t, pn.Valid)
|
assert.False(t, pn.Valid)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScanToSliceOfRenamedUint8(t *testing.T) {
|
||||||
|
m := pgtype.NewMap()
|
||||||
|
var ruint8 []_uint8
|
||||||
|
err := m.Scan(pgtype.Int2ArrayOID, pgx.TextFormatCode, []byte("{2,4}"), &ruint8)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []_uint8{2, 4}, ruint8)
|
||||||
|
}
|
||||||
|
|
||||||
func TestMapScanTextToBool(t *testing.T) {
|
func TestMapScanTextToBool(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -537,6 +546,14 @@ func TestMapEncodePlanCacheUUIDTypeConfusion(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgx/issues/1763
|
||||||
|
func TestMapEncodeRawJSONIntoUnknownOID(t *testing.T) {
|
||||||
|
m := pgtype.NewMap()
|
||||||
|
buf, err := m.Encode(0, pgtype.TextFormatCode, json.RawMessage(`{"foo": "bar"}`), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []byte(`{"foo": "bar"}`), buf)
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) {
|
func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) {
|
||||||
m := pgtype.NewMap()
|
m := pgtype.NewMap()
|
||||||
src := []byte{0, 0, 0, 42}
|
src := []byte{0, 0, 0, 42}
|
||||||
|
|||||||
+13
-1
@@ -52,7 +52,19 @@ func parseUUID(src string) (dst [16]byte, err error) {
|
|||||||
|
|
||||||
// encodeUUID converts a uuid byte array to UUID standard string form.
|
// encodeUUID converts a uuid byte array to UUID standard string form.
|
||||||
func encodeUUID(src [16]byte) string {
|
func encodeUUID(src [16]byte) string {
|
||||||
return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16])
|
var buf [36]byte
|
||||||
|
|
||||||
|
hex.Encode(buf[0:8], src[:4])
|
||||||
|
buf[8] = '-'
|
||||||
|
hex.Encode(buf[9:13], src[4:6])
|
||||||
|
buf[13] = '-'
|
||||||
|
hex.Encode(buf[14:18], src[6:8])
|
||||||
|
buf[18] = '-'
|
||||||
|
hex.Encode(buf[19:23], src[8:10])
|
||||||
|
buf[23] = '-'
|
||||||
|
hex.Encode(buf[24:], src[10:])
|
||||||
|
|
||||||
|
return string(buf[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the database/sql Scanner interface.
|
||||||
|
|||||||
@@ -417,12 +417,10 @@ type CollectableRow interface {
|
|||||||
// RowToFunc is a function that scans or otherwise converts row to a T.
|
// RowToFunc is a function that scans or otherwise converts row to a T.
|
||||||
type RowToFunc[T any] func(row CollectableRow) (T, error)
|
type RowToFunc[T any] func(row CollectableRow) (T, error)
|
||||||
|
|
||||||
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
|
// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
|
||||||
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
|
func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
slice := []T{}
|
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
value, err := fn(rows)
|
value, err := fn(rows)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -438,6 +436,11 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
|
|||||||
return slice, nil
|
return slice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
|
||||||
|
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
|
||||||
|
return AppendRows([]T{}, rows, fn)
|
||||||
|
}
|
||||||
|
|
||||||
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
|
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
|
||||||
// CollectOneRow is to CollectRows as QueryRow is to Query.
|
// CollectOneRow is to CollectRows as QueryRow is to Query.
|
||||||
func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
|
func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
|
||||||
|
|||||||
@@ -175,6 +175,21 @@ func TestCollectRows(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCollectRowsEmpty(t *testing.T) {
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
rows, _ := conn.Query(ctx, `select n from generate_series(1, 0) n`)
|
||||||
|
numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) {
|
||||||
|
var n int32
|
||||||
|
err := row.Scan(&n)
|
||||||
|
return n, err
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, numbers)
|
||||||
|
|
||||||
|
assert.Empty(t, numbers)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// This example uses CollectRows with a manually written collector function. In most cases RowTo, RowToAddrOf,
|
// This example uses CollectRows with a manually written collector function. In most cases RowTo, RowToAddrOf,
|
||||||
// RowToStructByPos, RowToAddrOfStructByPos, or another generic function would be used.
|
// RowToStructByPos, RowToAddrOfStructByPos, or another generic function would be used.
|
||||||
func ExampleCollectRows() {
|
func ExampleCollectRows() {
|
||||||
|
|||||||
@@ -1,6 +0,0 @@
|
|||||||
[ req ]
|
|
||||||
distinguished_name = dn
|
|
||||||
[ dn ]
|
|
||||||
commonName = ca
|
|
||||||
[ ext ]
|
|
||||||
basicConstraints =CA:TRUE,pathlen:0
|
|
||||||
@@ -0,0 +1,187 @@
|
|||||||
|
// Generates a CA, server certificate, and encrypted client certificate for testing pgx.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Create the CA
|
||||||
|
ca := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: "pgx-root-ca",
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().AddDate(20, 0, 0),
|
||||||
|
IsCA: true,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
caKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caKey.PublicKey, caKey)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = writePrivateKey("ca.key", caKey)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = writeCertificate("ca.pem", caBytes)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a server certificate signed by the CA for localhost.
|
||||||
|
serverCert := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(2),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: "localhost",
|
||||||
|
},
|
||||||
|
DNSNames: []string{"localhost"},
|
||||||
|
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().AddDate(20, 0, 0),
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
}
|
||||||
|
|
||||||
|
serverCertPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverBytes, err := x509.CreateCertificate(rand.Reader, serverCert, ca, &serverCertPrivKey.PublicKey, caKey)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = writePrivateKey("localhost.key", serverCertPrivKey)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = writeCertificate("localhost.crt", serverBytes)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a client certificate signed by the CA and encrypted.
|
||||||
|
clientCert := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(3),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: "pgx_sslcert",
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().AddDate(20, 0, 0),
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
}
|
||||||
|
|
||||||
|
clientCertPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientBytes, err := x509.CreateCertificate(rand.Reader, clientCert, ca, &clientCertPrivKey.PublicKey, caKey)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
writeEncryptedPrivateKey("pgx_sslcert.key", clientCertPrivKey, "certpw")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
writeCertificate("pgx_sslcert.crt", clientBytes)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writePrivateKey(path string, privateKey *rsa.PrivateKey) error {
|
||||||
|
file, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writePrivateKey: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pem.Encode(file, &pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writePrivateKey: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = file.Close()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writePrivateKey: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeEncryptedPrivateKey(path string, privateKey *rsa.PrivateKey, password string) error {
|
||||||
|
file, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writeEncryptedPrivateKey: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := x509.EncryptPEMBlock(rand.Reader, "CERTIFICATE", x509.MarshalPKCS1PrivateKey(privateKey), []byte(password), x509.PEMCipher3DES)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writeEncryptedPrivateKey: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pem.Encode(file, block)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writeEncryptedPrivateKey: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = file.Close()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writeEncryptedPrivateKey: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeCertificate(path string, certBytes []byte) error {
|
||||||
|
file, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writeCertificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pem.Encode(file, &pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: certBytes,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writeCertificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = file.Close()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writeCertificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
[ req ]
|
|
||||||
default_bits = 2048
|
|
||||||
distinguished_name = dn
|
|
||||||
req_extensions = v3_req
|
|
||||||
prompt = no
|
|
||||||
[ dn ]
|
|
||||||
commonName = localhost
|
|
||||||
[ v3_req ]
|
|
||||||
subjectAltName = @alt_names
|
|
||||||
keyUsage = digitalSignature
|
|
||||||
extendedKeyUsage = serverAuth
|
|
||||||
[alt_names]
|
|
||||||
DNS.1 = localhost
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
[ req ]
|
|
||||||
default_bits = 2048
|
|
||||||
distinguished_name = dn
|
|
||||||
req_extensions = v3_req
|
|
||||||
prompt = no
|
|
||||||
[ dn ]
|
|
||||||
commonName = pgx_sslcert
|
|
||||||
[ v3_req ]
|
|
||||||
keyUsage = digitalSignature
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
-- Create extensions and types.
|
-- Create extensions and types.
|
||||||
create extension hstore;
|
create extension hstore;
|
||||||
|
create extension ltree;
|
||||||
create domain uint64 as numeric(20,0);
|
create domain uint64 as numeric(20,0);
|
||||||
|
|
||||||
-- Create users for different types of connections and authentication.
|
-- Create users for different types of connections and authentication.
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user